# Run everything in this section to define functions and imports. 

In [1]:
# Imports. 
import os
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn 
import torch.nn.functional as F
import numpy as np
import json
import csv
from collections import Counter

### Define the paths below for code to run. 

In [2]:
snare_path = '/home/rcorona/obj_part_lang/snare-master/amt/folds_adversarial'
metadata_path = './data/metadata.csv'
categories_path = './data/categories.synset.csv'
lfn_feat_dir = './data/lfn_feats'
clip_feat_dir = '/home/rcorona/dev/snare-master/data/shapenet-clipViT32-frames/'
pixelnerf_feat_dir = '/home/rcorona/2022/lang_nerf/vlg/snare-master/data/pixelnerf_custom_feats'
legoformer_feat_dir = '/home/rcorona/2022/lang_nerf/vlg/snare-master/data/legoformer_multiview_feats/'

# TODO Set GPU number here. 
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [3]:
def get_snare_objs(): 
    """
    Get all ShapeNetSem object IDs for objects used in SNARE. 
    """
    train = json.load(open(os.path.join(snare_path, 'train.json')))
    val = json.load(open(os.path.join(snare_path, 'val.json')))
    test = json.load(open(os.path.join(snare_path, 'test.json')))

    train_objs = set()
    val_objs = set()
    test_objs = set()

    # Comb through snare files to collect unique set of ShapeNet objects. 
    snare_objs = set()

    for obj_set, split in [(train_objs, train), (val_objs, val), (test_objs, test)]:
        for datapoint in split: 
            for obj in datapoint['objects']:
                obj_set.add(obj)

    all_objs = train_objs | val_objs | test_objs

    return list(all_objs)

In [5]:
def get_plot_labels(objs):
    """
    Generate category labels for TSNE plot. 
    """
    
    # Load all metadata for objects. 
    with open(metadata_path, 'r') as csvfile: 
        metadata = [row for row in csv.reader(csvfile)]
        
    # Get index of each object in metdata. 
    obj2synset = {m[0].replace('wss.', '').strip(): m[2].strip() for m in metadata}
        
    # Load synset word mappings. 
    with open(categories_path, 'r') as csvfile: 
        mappings = [row for row in csv.reader(csvfile)]
        
        # Mapping from synset to word. 
        s2w = {r[2].strip(): r[3].split(',')[0].strip() for r in mappings[1:]}    


    # Get set of all synsets. 
    synset_codes = set([r[2].strip() for r in metadata[2:]])    
    synsets = []

    for s in synset_codes: 
        try: 
            synset = s2w[s]
        except: 
            synset = 'None'
            
        synsets.append(synset)
        

    ## Get 10 most common object categories and filter out everything else.  
    snare_synsets = []
    counts = Counter()

    # Count object categories. 
    for obj in objs:
        synset = obj2synset[obj]

        # Only count those with label. 
        if synset in s2w:
            word = s2w[synset]
            snare_synsets.append(word)
            
            # Update word count.
            if not word == '': 
                counts[word] += 1
        else: 
            snare_synsets.append(None)
            
    return snare_synsets, counts

In [42]:
def load_lfn_features(objs):
    """
    Given list of object IDs, load features for each object under LFN.  
    """
    # Load all features and name order. 
    feats = []
    
    for obj in objs:
        
        # Load feature. 
        path = os.path.join(lfn_feat_dir, '{}.npy'.format(obj))
        feat = feats.append(np.load(path))
        
    return feats

In [6]:
def load_clip_features(objs):
    """
    Given list of object IDs, load features for each object under CLIP.  
    """
    # Load all features and name order. 
    feats = []

    for obj in objs:
        
        # Load image features for object. 
        path = os.path.join(clip_feat_dir, '{}.npy'.format(obj))
        
        # Load features for input views and take mean. 
        feat = np.mean(np.load(path)[6:], axis=0)
        feats.append(feat)
        
    return feats

In [7]:
def load_legoformer_features(objs):
    """
    Given list of object IDs, load features for each object under LegoFormer.  
    """
    # Load all features and name order. 
    feats = []

    for obj in objs:
        
        # Load image features for object. 
        path = os.path.join(legoformer_feat_dir, '{}.npy'.format(obj))
        
        # Load and collapse dimension to consider as single feature. 
        feats.append(np.reshape(np.load(path), -1))
        
    return feats

In [8]:
def load_pixelnerf_features(objs):
    """
    Given list of object IDs, load features for each object under PixelNeRF.  
    """
    # Load all features and name order. 
    feats = []

    print('Loading PixelNeRF features...')

    for obj in tqdm(objs):
        
        # Load image features for object. 
        path = os.path.join(pixelnerf_feat_dir, '{}.npy'.format(obj))
        
        # Load features for input views and take mean. 
        feat = np.reshape(np.load(path), (8, 512, -1))
        feat = np.mean(np.mean(feat, axis=0), axis=-1)
        feats.append(feat)
        
    return feats

In [9]:
def filter_by_top_k(counts, feats, objs, snare_synsets):
    # Get 10 most common categories. 
    top10 = {t[0]: t[1] for t in counts.most_common(10)}

    # Only keep objects in top-10 categories. 
    final_feats = []
    final_labels = []

    assert len(objs) == len(feats) and len(objs) == len(snare_synsets)

    for i in range(len(objs)):
        synset = snare_synsets[i]
        
        if synset in top10: 
            final_feats.append(feats[i])
            final_labels.append(synset)
            
    # Create numpy array of features. 
    final_feats = np.stack(final_feats)
    print('Final feature shape: {}'.format(final_feats.shape))
    
    return final_feats, final_labels

In [10]:
# Generate list of all ShapeNetSem objects used in SNARE. 
snare_objs = get_snare_objs()

# Generate category labels for TSNE plot. 
synsets, counts = get_plot_labels(snare_objs)

# Linear Probe Functions and Classes

In [11]:
# Dataset class for linear probes. 
class ProbeDataset(Dataset):
    
    def __init__(self, feats, labels, idx_dict):
        
        # List of objects in this dataset split. 
        self.feats = feats
        
        # Synset labels for objects. 
        self.labels = labels
        
        # Holds ID mappings for labels. 
        self.idx_dict = idx_dict
        
    def __len__(self):
        return len(self.feats)
    
    def __getitem__(self, idx): 
        
        # Load object features and label. 
        feat = torch.Tensor(self.feats[idx]).float()
        label = self.idx_dict[self.labels[idx]]
        
        return feat, label

In [17]:
class LinearProbe(nn.Module):
    
    def __init__(self, feat_dim, n_categories, mlp=False):
        super().__init__()
        self.feat_dim = feat_dim
        self.n_categories = n_categories
        
        # Simple linear probe. 
        if not mlp: 
            self.probe = nn.Linear(feat_dim, n_categories)
        else: 
            self.probe = nn.Sequential(
                nn.Linear(feat_dim, feat_dim // 2), 
                nn.ReLU(), 
                nn.Linear(feat_dim // 2, n_categories)
            )
        
    def forward(self, x):
        return self.probe(x)

In [13]:
# Training loop. 
def train_loop(model, train_dataloader):
    
    # Initialize optimizer. 
    optim = torch.optim.Adam(model.parameters())
    
    # Place model on train mode. 
    model.train()
    
    # Do one epoch of updates. 
    for feats, labels in train_dataloader:
        
        # Put on GPU. 
        feats = feats.cuda()
        labels = labels.cuda()
        
        # Zero out gradients on optimizer. 
        optim.zero_grad()
        
        # Forward pass. 
        logits = model(feats)
        
        # Compute CE loss and take update.  
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        optim.step()

In [14]:
# Evaluate model accuracy on a dataset split. 
def eval_model(model, loader):
    
    # Keep track of accuracy across all datapoints. 
    all_correct = []
    
    # Place model on eval mode. 
    model.eval()
    
    # Go over entire dataset. 
    with torch.no_grad():
        for feats, labels in loader:
            
            # Put on GPU. 
            feats = feats.cuda()
            labels = labels.cuda()
            
            # Forward pass and prediction. 
            logits = model(feats)
            preds = logits.argmax(dim=1)
            
            # Compute accuracy across batch and add to list over dataset. 
            correct = torch.eq(preds, labels).long()
            all_correct.append(correct.cpu().numpy())
            
    # Compute dataset split accuracy. 
    all_correct = np.concatenate(all_correct)
    acc = np.mean(all_correct)
    
    return acc 

In [15]:
# Training pipeline. 
def run_probe(model, train_loader, val_loader, test_loader, n_epochs):
    
    # Keep track of best model checkpoint for test set. 
    best_acc = 0.0
    
    for i in range(n_epochs):
        
        # Do a training iteration with model. 
        train_loop(model, train_loader)
        
        # Evaluate model on validation set. 
        val_acc = eval_model(model, val_loader)
        
        # Keep best performing model checkpoint. 
        if val_acc > best_acc: 
            torch.save(model.state_dict(), 'probe.pth')
            best_acc = val_acc
        
        # Print best accuracy. 
        print('Best Acc: {}'.format(best_acc))
            
    # Evaluate best checkpoint on test set.
    model.load_state_dict(torch.load('probe.pth')) 
    test_acc = eval_model(model, test_loader)
    print('Probe test performance: {}'.format(test_acc))
    
    # Get rid of probe temp path. 
    os.remove('probe.pth')

In [16]:
def linear_probe(snare_objs, counts, synsets, feat_load_func, feat_dim, mlp=False):

    # General hyperparameters. 
    batch_size = 64
    n_epochs = 100
    n_categories = 10

    # Load features and labels.
    feats = feat_load_func(snare_objs)

    # Filter features by top 10 most occurring categories. 
    final_feats, final_labels = filter_by_top_k(counts, feats, snare_objs, synsets)

    # Compute dataset split lengths. 
    train_len = int(float(len(final_feats)) * 0.8)
    val_len = int(float(len(final_feats)) * 0.1)

    # Compute ID dictionary for labels. 
    label_list = list(set(final_labels))
    idx_dict = {label_list[i]: i for i in range(len(label_list))} 

    # Split into datasets. 
    train_dataset = ProbeDataset(final_feats[:train_len], final_labels[:train_len], idx_dict)
    val_dataset = ProbeDataset(final_feats[train_len:train_len+val_len], final_labels[train_len:train_len+val_len], idx_dict)
    test_dataset = ProbeDataset(final_feats[train_len+val_len:], final_labels[train_len+val_len:], idx_dict)

    # Form dataloaders. 
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    
    # Instantiate model. 
    model = LinearProbe(feat_dim, 10, mlp=mlp)
    model.cuda()
    
    # Run linear probe. 
    run_probe(model, train_loader, val_loader, test_loader, n_epochs)

# CLIP Linear Probe

In [54]:
print('Training CLIP Linear Probe...')
linear_probe(snare_objs, counts, synsets, load_clip_features, 512)

Training CLIP Linear Probe...
Final feature shape: (2273, 512)
Best Acc: 0.6784140969162996
Best Acc: 0.8061674008810573
Best Acc: 0.8281938325991189
Best Acc: 0.8546255506607929
Best Acc: 0.8590308370044053
Best Acc: 0.8722466960352423
Best Acc: 0.8722466960352423
Best Acc: 0.8766519823788547
Best Acc: 0.8766519823788547
Best Acc: 0.8766519823788547
Best Acc: 0.8766519823788547
Best Acc: 0.8766519823788547
Best Acc: 0.8766519823788547
Best Acc: 0.8766519823788547
Best Acc: 0.8766519823788547
Best Acc: 0.8766519823788547
Best Acc: 0.8766519823788547
Best Acc: 0.8766519823788547
Best Acc: 0.8766519823788547
Best Acc: 0.8766519823788547
Best Acc: 0.8766519823788547
Best Acc: 0.8766519823788547
Best Acc: 0.8766519823788547
Best Acc: 0.8766519823788547
Best Acc: 0.8810572687224669
Best Acc: 0.8810572687224669
Best Acc: 0.8810572687224669
Best Acc: 0.8810572687224669
Best Acc: 0.8810572687224669
Best Acc: 0.8810572687224669
Best Acc: 0.8810572687224669
Best Acc: 0.8810572687224669
Best Acc:

# VLG (LegoFormer) Linear Probe

In [55]:
print('Training LegoFormer Linear Probe...')
linear_probe(snare_objs, counts, synsets, load_legoformer_features, 96 * 12)

Training LegoFormer Linear Probe...
Final feature shape: (2273, 1152)
Best Acc: 0.4977973568281938
Best Acc: 0.5638766519823789
Best Acc: 0.5638766519823789
Best Acc: 0.6123348017621145
Best Acc: 0.6387665198237885
Best Acc: 0.6387665198237885
Best Acc: 0.6431718061674009
Best Acc: 0.6696035242290749
Best Acc: 0.6696035242290749
Best Acc: 0.6784140969162996
Best Acc: 0.6784140969162996
Best Acc: 0.7004405286343612
Best Acc: 0.7004405286343612
Best Acc: 0.7004405286343612
Best Acc: 0.7004405286343612
Best Acc: 0.7004405286343612
Best Acc: 0.7004405286343612
Best Acc: 0.7004405286343612
Best Acc: 0.7004405286343612
Best Acc: 0.7004405286343612
Best Acc: 0.7048458149779736
Best Acc: 0.7048458149779736
Best Acc: 0.7048458149779736
Best Acc: 0.7048458149779736
Best Acc: 0.7048458149779736
Best Acc: 0.7092511013215859
Best Acc: 0.7092511013215859
Best Acc: 0.7092511013215859
Best Acc: 0.7180616740088106
Best Acc: 0.7180616740088106
Best Acc: 0.7180616740088106
Best Acc: 0.7180616740088106
Be

# PixelNeRF Linear Probe

In [56]:
print('Training PixelNeRF Linear Probe...')
linear_probe(snare_objs, counts, synsets, load_pixelnerf_features, 512)

Training PixelNeRF Linear Probe...
Loading PixelNeRF features...


100%|██████████| 7881/7881 [03:48<00:00, 34.54it/s]


Final feature shape: (2273, 512)
Best Acc: 0.29515418502202645
Best Acc: 0.3436123348017621
Best Acc: 0.4669603524229075
Best Acc: 0.47577092511013214
Best Acc: 0.47577092511013214
Best Acc: 0.47577092511013214
Best Acc: 0.5594713656387665
Best Acc: 0.5594713656387665
Best Acc: 0.5594713656387665
Best Acc: 0.5594713656387665
Best Acc: 0.5814977973568282
Best Acc: 0.6299559471365639
Best Acc: 0.6916299559471366
Best Acc: 0.6916299559471366
Best Acc: 0.6916299559471366
Best Acc: 0.7004405286343612
Best Acc: 0.7092511013215859
Best Acc: 0.7092511013215859
Best Acc: 0.7400881057268722
Best Acc: 0.7709251101321586
Best Acc: 0.7709251101321586
Best Acc: 0.7709251101321586
Best Acc: 0.7797356828193832
Best Acc: 0.7797356828193832
Best Acc: 0.7797356828193832
Best Acc: 0.7797356828193832
Best Acc: 0.7797356828193832
Best Acc: 0.7797356828193832
Best Acc: 0.7797356828193832
Best Acc: 0.7841409691629956
Best Acc: 0.7841409691629956
Best Acc: 0.7841409691629956
Best Acc: 0.7841409691629956
Best A

# LFN Linear Probe

In [57]:
print('Training LFN Linear Probe...')
linear_probe(snare_objs, counts, synsets, load_lfn_features, 256, mlp=True)

Training LFN Linear Probe...
Final feature shape: (2273, 256)
Best Acc: 0.15859030837004406
Best Acc: 0.15859030837004406
Best Acc: 0.18502202643171806
Best Acc: 0.18502202643171806
Best Acc: 0.18502202643171806
Best Acc: 0.18502202643171806
Best Acc: 0.18502202643171806
Best Acc: 0.18502202643171806
Best Acc: 0.18502202643171806
Best Acc: 0.18502202643171806
Best Acc: 0.2026431718061674
Best Acc: 0.2026431718061674
Best Acc: 0.21145374449339208
Best Acc: 0.21145374449339208
Best Acc: 0.21145374449339208
Best Acc: 0.21145374449339208
Best Acc: 0.21145374449339208
Best Acc: 0.21145374449339208
Best Acc: 0.21585903083700442
Best Acc: 0.21585903083700442
Best Acc: 0.22026431718061673
Best Acc: 0.22026431718061673
Best Acc: 0.22026431718061673
Best Acc: 0.22026431718061673
Best Acc: 0.22026431718061673
Best Acc: 0.22026431718061673
Best Acc: 0.22026431718061673
Best Acc: 0.22466960352422907
Best Acc: 0.22466960352422907
Best Acc: 0.22466960352422907
Best Acc: 0.22466960352422907
Best Acc: 