### Imports & Setup

In [32]:
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Load CLEVR data & Define functions to build vocabs

In [57]:
def build_answer_vocab(questions):
    answers = sorted(set(q["answer"] for q in questions))
    return {a: i for i, a in enumerate(answers)}

def build_attr_vocab(questions):
    attrs = set()
    for q in questions:
        for step in q["program"]:
            if step["function"].startswith("filter_"):
                attrs.add(step["value_inputs"][0])
    return {a:i for i,a in enumerate(sorted(attrs))}

with open("data/CLEVR_v1.0/questions/CLEVR_train_questions.json") as f:
    questions_json = json.load(f)["questions"]

ANSWER_VOCAB = build_answer_vocab(questions_json)
IDX_TO_ANSWER = {v: k for k, v in ANSWER_VOCAB.items()}
NUM_ANSWERS = len(ANSWER_VOCAB)

ATTR_VOCAB = build_attr_vocab(questions_json)
ATTR_VOCAB_SIZE = len(ATTR_VOCAB)

REL_VOCAB = {"left":0, "right":1, "front":2, "behind":3}
REL_VOCAB_SIZE = len(REL_VOCAB)

print("Number of answers:", NUM_ANSWERS)
print("Num attributes:", ATTR_VOCAB_SIZE)

Number of answers: 28
Num attributes: 15


### CLEVR -> NMN

In [101]:
def clevr_to_nmn(program, attr_vocab, rel_vocab=None):
    # If already NMN-style tuples, return unchanged
    if len(program) > 0 and isinstance(program[0], tuple):
        return program

    nmn_program = []
    for step in program:
        f = step["function"]

        if f == "scene":
            nmn_program.append(("scene", None))

        elif f.startswith("filter_"):
            attr = step["value_inputs"][0]
            nmn_program.append(("attend", attr_vocab[attr]))

        elif f.startswith("relate_"):
            # optional: if you want relation types
            rel = f.split("_")[1]
            if rel_vocab:
                nmn_program.append(("relate", rel_vocab[rel]))
            else:
                nmn_program.append(("relate", None))
        
        elif f in ("intersect", "union"):
            nmn_program.append(("combine", None))

        elif f == "unique":
            nmn_program.append(("unique", None))

        elif f == "count":
            nmn_program.append(("measure", "count"))
        elif f == "exist":
            nmn_program.append(("measure", "exist"))

        elif f == "query" or f.startswith("query"):
            # simplify to a measure (only for demo)
            nmn_program.append(("measure", "query"))
        
        elif f == "same" or f.startswith("same_"):
            # simplify to a measure (only for demo)
            nmn_program.append(("measure", "same"))

        # unknown function
        else:
            # treat unknown as attend
            nmn_program.append(("attend", 0))

    return nmn_program

In [90]:
def normalize_program(program):
    normalized = []
    for op, arg in program:
        if op == "classify":
            normalized.append(("measure", arg))
        else:
            normalized.append((op, arg))
    return normalized

### Backbone

In [91]:
backbone = models.resnet18(pretrained=True)
backbone = nn.Sequential(*list(backbone.children())[:-2])
backbone = backbone.to(device)
for p in backbone.parameters():
    p.requires_grad = False

backbone.eval()

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Con

### Define Modules

In [92]:
class AttendModule(nn.Module):
    def __init__(self, feat_dim, vocab_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, feat_dim)
        self.conv = nn.Conv2d(feat_dim, feat_dim, kernel_size=1)

    def forward(self, features, attr_idx):
        w = self.embed(attr_idx).unsqueeze(-1).unsqueeze(-1)
        x = features * w
        attn = self.conv(x)
        return torch.sigmoid(attn)

class RelateModule(nn.Module):
    def __init__(self, feat_dim, rel_vocab_size):
        super().__init__()
        self.embed = nn.Embedding(rel_vocab_size, feat_dim)
        self.conv = nn.Conv2d(feat_dim, feat_dim, kernel_size=1)

    def forward(self, attn, rel_idx):
        r = self.embed(rel_idx).unsqueeze(-1).unsqueeze(-1)
        x = attn * r
        return self.conv(x)

class CombineModule(nn.Module):
    def forward(self, a1, a2):
        return a1 * a2

class ClassifyModule(nn.Module):
    def __init__(self, feat_dim, num_answers):
        super().__init__()
        self.fc = nn.Linear(feat_dim, num_answers)

    def forward(self, attn, features):
        weighted = (attn * features).sum(dim=[2, 3])
        return self.fc(weighted)
    
class UniqueModule(nn.Module):
    def __init__(self, feat_dim):
        super().__init__()
        self.conv = nn.Conv2d(feat_dim, feat_dim, kernel_size=1)

    def forward(self, attn):
        return self.conv(attn)

class MeasureModule(nn.Module):
    def __init__(self, num_answers, feat_dim=512):
        super().__init__()
        self.fc = nn.Linear(feat_dim, num_answers)

    def forward(self, attn, op):
        pooled = attn.mean(dim=(2,3))
        return self.fc(pooled)


In [109]:
class NMNComposer(nn.Module):
    def __init__(self, module_dict):
        super().__init__()
        self.module_dict = nn.ModuleDict(module_dict)

    def forward(self, features, program):
        B, C, H, W = features.shape
        stack = []

        for op, arg in program:
            if op == "scene":
                stack = []
                stack.append(torch.ones((B,1,H,W), device=features.device))

            elif op == "attend":
                stack.append(self.module_dict["attend"](features, torch.tensor([arg]).to(features.device)))

            elif op == "relate":
                a = stack.pop()
                stack.append(self.module_dict["relate"](a, arg))

            elif op == "unique":
                a = stack.pop()
                stack.append(self.module_dict["unique"](a))

            elif op == "measure":
                a = stack.pop()
                return self.module_dict["measure"](a, features)

        raise RuntimeError("Program ended without measure")

In [94]:
class CLEVRNMNDataset(Dataset):
    def __init__(self, questions_json, image_dir, transform, answer_vocab):
        # Load CLEVR questions
        with open(questions_json) as f:
            self.questions = json.load(f)["questions"]

        self.image_dir = image_dir
        self.transform = transform
        self.answer_vocab = answer_vocab

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        q = self.questions[idx]

        # Load image
        img_path = os.path.join(
            self.image_dir,
            f"CLEVR_train_{q['image_index']:06d}.png"
        )
        image = Image.open(img_path).convert("RGB")
        image = self.transform(image)

        # Convert CLEVR program to NMN format
        program = clevr_to_nmn(q["program"], ATTR_VOCAB)

        # Convert answer to integer
        answer = self.answer_vocab[q["answer"]]

        return image, program, torch.tensor(answer)

In [111]:
def collate_fn(batch):
    images = torch.stack([b[0] for b in batch])
    programs = [b[1] for b in batch]
    answers = torch.tensor([b[2] for b in batch])
    return images, programs, answers

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

dataset = CLEVRNMNDataset(
    questions_json="data/CLEVR_v1.0/questions/CLEVR_train_questions.json",
    image_dir="data/CLEVR_v1.0/images/train",
    transform=transform,
    answer_vocab=ANSWER_VOCAB
)

subset = torch.utils.data.Subset(dataset, list(range(1000)))

dataloader = DataLoader(
    subset,
    batch_size=4,
    shuffle=True,
    collate_fn=collate_fn
)

In [112]:
q = dataset.questions[0]
print(type(q["program"]))
print(q["program"][:3])

<class 'list'>
[{'inputs': [], 'function': 'scene', 'value_inputs': []}, {'inputs': [0], 'function': 'filter_size', 'value_inputs': ['large']}, {'inputs': [1], 'function': 'filter_color', 'value_inputs': ['green']}]


In [113]:
modules = {
    "attend": AttendModule(512, ATTR_VOCAB_SIZE).to(device),
    "relate": RelateModule(512, REL_VOCAB_SIZE).to(device),
    "combine": CombineModule().to(device),
    "unique": UniqueModule(512).to(device),
    "measure": MeasureModule(NUM_ANSWERS).to(device)
}

batch = next(iter(dataloader))
images, programs, answers = batch

# choose sample 0 from the batch
image = images[0].unsqueeze(0).to(device)
program = programs[0]
answer = answers[0].to(device)

features = backbone(image)

converted_program = clevr_to_nmn(program, ATTR_VOCAB)

composer = NMNComposer(modules).to(device)
logits = composer(features, converted_program)

print("NMN program:", converted_program)
print("Logits shape:", logits.shape)

NMN program: [('scene', None), ('attend', 12), ('attend', 5), ('attend', 8), ('attend', 4), ('measure', 'count')]
Logits shape: torch.Size([1, 28])


### Training Helper Function

In [120]:
def train_epoch(dataloader, backbone, composer, optimizer, device):
    total_loss = 0
    correct = 0
    total = 0

    for images, programs, answers in dataloader:
        images = images.to(device)
        answers = answers.to(device)

        features = backbone(images)

        batch_logits = []
        for i, prog in enumerate(programs):
            converted = clevr_to_nmn(prog, ATTR_VOCAB)
            logits = composer(features[i:i+1], converted)
            batch_logits.append(logits)

        logits = torch.cat(batch_logits, dim=0)
        loss = F.cross_entropy(logits, answers)

        optimizer.zero_grad()
        loss.backward()
        #for name, p in composer.named_parameters():
        #    if p.grad is None:
        #        print("NO GRAD:", name)
        #    else:
        #        print("OK:", name, p.grad.abs().mean().item())
        optimizer.step()

        total_loss += loss.item() * images.size(0)
        pred = logits.argmax(dim=1)
        correct += (pred == answers).sum().item()
        total += images.size(0)

    return total_loss / total, correct / total

In [123]:
optimizer = Adam(
    composer.parameters(),
    lr=3e-4
)

In [125]:
NUM_EPOCHS = 5

for epoch in range(NUM_EPOCHS):
    loss, acc = train_epoch(dataloader, backbone, composer, optimizer, device)
    print(f"Epoch {epoch+1} | Loss: {loss:.4f} | Accuracy: {acc:.4f}")

Epoch 1 | Loss: 2.3325 | Accuracy: 0.2760
Epoch 2 | Loss: 2.2923 | Accuracy: 0.2770
Epoch 3 | Loss: 2.2335 | Accuracy: 0.2820
Epoch 4 | Loss: 2.1988 | Accuracy: 0.2860
Epoch 5 | Loss: 2.1475 | Accuracy: 0.3100
