In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision.transforms as transforms
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
from typing import Any, Dict, Optional, Tuple, Union, List
import os, json, random
from collections import Counter, defaultdict
import numpy as np
from tqdm import tqdm, trange

In [2]:
# To ensure we get reproducible results
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [33]:
randaug = transforms.RandomApply(transforms=[
    #transforms.RandomRotation(20, interpolation=Image.BILINEAR),
    transforms.ColorJitter(brightness=.3, contrast=.3, saturation=0.3, hue=0),
    #transforms.RandomPerspective(distortion_scale=0.8, p=0.6),
], p=0.5)

In [34]:
class dataset(Dataset):
    def __init__(self,
                 classes: List[str], # used to define class_ids, order matters
                 imdir: str,
                 data: List,
                 imsize = (64, 64),
                 ):
        super().__init__()
        self.imdir = imdir
        self.data = data
        self.preprocess = transforms.Compose(
            [   
                transforms.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
                randaug,
                transforms.RandomResizedCrop(imsize, scale=(0.8, 1.0), ratio=(0.5, 1.5)),
                #transforms.Resize(imsize),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )
        self.classes = classes
        #self.n2i = {n:i+1 for i, n in enumerate(classes)}

    def __len__(self): return len(self.data)
    
    def __getitem__(self, i): 
        text, image_path, tuple = self.data[i]
        image = Image.open(os.path.join(self.imdir, image_path))
        subj, obj, r = tuple
        label = torch.zeros((18, ), dtype=torch.int64)
        for i in range(len(self.classes)):
            if self.classes[i] in [subj, obj]:
                label[i] = 1

        return {
            'sentence': text,
            'image': self.preprocess(image),
            'label': label
        }
    


In [5]:
def get_acc(output, gth, num_classes):
    # output: (bs*num_classes, 2)
    # gth: (bs*num_classes,)
    pred = output.argmax(1, keepdim=True).reshape((-1, num_classes))
    gth = gth.reshape((-1, num_classes))
    correct = sum(pred.eq(gth).sum(axis=1) == num_classes)
    acc = correct.float() / gth.shape[0]
    return acc

In [6]:
def train(net, dataloader, optimizer, criterion, num_classes, device):
    net.train()
    running_loss, running_acc = [], []
    for batch in dataloader:
        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        bs = len(batch['image'])
        outputs = net(batch['image'].to(device)).reshape((bs*num_classes, -1)) 

        labels = batch['label'].flatten().to(device) # (bs*num_classes, )

        loss = criterion(outputs, labels)
        acc = get_acc(outputs, labels, num_classes)

        loss.backward()
        optimizer.step()

        # print statistics
        running_loss.append(loss.item())
        running_acc.append(acc.item())
    return np.mean(running_loss), np.mean(running_acc)

def val(net, dataloader, criterion, num_classes, device):
    net.eval()
    epoch_loss, epoch_acc = [], []

    with torch.no_grad():
        for batch in dataloader:
            bs = len(batch['image'])
            outputs = net(batch['image'].to(device)).reshape((bs*num_classes, -1)) 

            labels = batch['label'].flatten().to(device)

            loss = criterion(outputs, labels)
            acc = get_acc(outputs, labels, num_classes)

            # print statistics
            epoch_loss.append(loss.item())
            epoch_acc.append(acc.item())
        
    return np.mean(epoch_loss), np.mean(epoch_acc)

In [35]:
annotations = json.load(open("../data/aggregated/whatsup_vlm_b.json", "r"))
occurrences = [a[-1][0] for a in annotations] + [a[-1][1] for a in annotations]

c = Counter(occurrences)
classes = sorted(c.keys(), key=lambda x: (-c[x], x))
D = dataset(
    classes, 
    imdir = "/data/yingshac/clevr_control/data/",
    data = annotations
)

device="cuda:0"

train_ratio = 0.8
train_num = int(train_ratio*len(D))
val_num = len(D) - train_num
train_data, val_data = random_split(D, [train_num, val_num])
print(f"{len(train_data)} training examples, {len(val_data)} testing examples")

num_classes = len(D.classes)

326 training examples, 82 testing examples


In [36]:

class Net(nn.Module):
    def __init__(self,
                 num_classes: int,
                 ):
        super().__init__()
        self.classifiers = nn.ModuleList([nn.Sequential(
            nn.Conv2d(3, 64, 5),
            nn.ReLU(),
            nn.MaxPool2d(2, 2), 
            nn.Conv2d(64, 128, 3), 
            nn.ReLU(),
            nn.MaxPool2d(2, 2), 
            nn.Conv2d(128, 128, 3),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 8, 6), # (bs, 8, 1, 1)
            nn.Flatten(1), # (bs, 8)
            nn.Linear(8, 2, bias=False) # (bs, 2)
        ) for i in range(num_classes)])

    def forward(self, x):
        res = []
        for c in self.classifiers:
            #y = x
            #for layer in c:
                #y = layer(y)
            res.append(c(x)) # (bs, 2)
        res = torch.stack(res, axis=0).transpose(0, 1) # (bs, num_classes, 2)
        return res


net = Net(len(D.classes)).to(device) # +1 for black_id

In [37]:
batch_size = 16
trainloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)
testloader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=4)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1).to(device)
optimizer = optim.Adam(net.parameters(), lr=0.001) #, momentum=0.9)

In [38]:
best_val_loss = float('inf')
best_val_acc = 0.0
for epc in range(100):

    train_loss, train_acc = train(net, trainloader, optimizer, criterion, num_classes, device)
    val_loss, val_acc = val(net, testloader, criterion, num_classes, device)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        #torch.save(model.state_dict(), os.path.join(config["save_dir"], config["date"], "model.pt"))
    
    print(f'Epoch: {epc+1:02}')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {val_loss:.3f} |  Val. Acc: {val_acc*100:.2f}%')
    
print("Training: finish\n")



Epoch: 01
	Train Loss: 0.483 | Train Acc: 0.00%
	 Val. Loss: 0.421 |  Val. Acc: 0.00%
Epoch: 02
	Train Loss: 0.407 | Train Acc: 0.30%
	 Val. Loss: 0.398 |  Val. Acc: 0.00%
Epoch: 03
	Train Loss: 0.383 | Train Acc: 1.19%
	 Val. Loss: 0.380 |  Val. Acc: 1.04%
Epoch: 04
	Train Loss: 0.363 | Train Acc: 1.69%
	 Val. Loss: 0.371 |  Val. Acc: 2.08%
Epoch: 05
	Train Loss: 0.351 | Train Acc: 6.55%
	 Val. Loss: 0.348 |  Val. Acc: 7.29%
Epoch: 06
	Train Loss: 0.327 | Train Acc: 11.31%
	 Val. Loss: 0.341 |  Val. Acc: 9.38%
Epoch: 07
	Train Loss: 0.318 | Train Acc: 12.50%
	 Val. Loss: 0.327 |  Val. Acc: 8.33%
Epoch: 08
	Train Loss: 0.307 | Train Acc: 18.85%
	 Val. Loss: 0.316 |  Val. Acc: 21.88%
Epoch: 09
	Train Loss: 0.299 | Train Acc: 26.39%
	 Val. Loss: 0.315 |  Val. Acc: 15.62%
Epoch: 10
	Train Loss: 0.291 | Train Acc: 28.17%
	 Val. Loss: 0.312 |  Val. Acc: 25.00%
Epoch: 11
	Train Loss: 0.285 | Train Acc: 32.44%
	 Val. Loss: 0.316 |  Val. Acc: 17.71%
Epoch: 12
	Train Loss: 0.276 | Train Acc: 40