# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import trange, tqdm
import wandb
import argparse


from main import get_model
import agent_net
from visda17 import get_visda_dataloaders, downsample_dataset
from utils import AverageMeter
from gumbel_softmax import *

# Pipeline:
# 1. load the ResNet-26 model (optionally initialized with ImageNet ("IN") weights)
# 2. build a training dataset from the synthetic VisDA data
# 3. build a validation dataset from the real VisDA test images
# 4. train with SpotTune (or optionally with SpotTune turned off)

def train_epoch(train_loader, net, agent=None):
    """Run one training epoch of ``net`` (and optionally the SpotTune ``agent``).

    Each batch is classified with cross-entropy loss. Running batch accuracy
    and loss are logged to wandb every ``args.log_interval`` batches.

    Relies on module-level globals defined later in this file: ``args`` (for
    ``log_interval``), ``optimizer`` and, when ``agent`` is given,
    ``agent_optimizer``.

    Args:
        train_loader: iterable yielding ``(images, labels)`` batches.
        net: classifier called as ``net.forward(images, policy)``; its output
            is fed to ``nn.CrossEntropyLoss`` as logits.
        agent: optional policy network. Its output is reshaped to
            ``(batch, -1, 2)`` and passed through ``gumbel_softmax``. Column 1
            of the last axis becomes the ``policy`` given to ``net``. ``None``
            means no policy (``policy=None``).
    """
    net.train()
    if agent is not None:
        # Make sure the agent is in training mode regardless of any earlier
        # eval() call, so its dropout/batch-norm layers behave correctly.
        agent.train()

    # The loss module is stateless; build it once instead of once per batch.
    criterion = nn.CrossEntropyLoss()
    train_acc = AverageMeter()
    train_loss = AverageMeter()

    for batch_idx, (images, labels) in enumerate(tqdm(train_loader)):
        if torch.cuda.is_available():
            images = images.cuda(non_blocking=True)
            labels = labels.cuda(non_blocking=True)

        if agent is not None:
            probs = agent(images)
            action = gumbel_softmax(probs.view(probs.size(0), -1, 2))
            policy = action[:, :, 1]
        else:
            policy = None

        outputs = net.forward(images, policy)
        loss = criterion(outputs, labels)

        # Accuracy is bookkeeping only, so keep it out of the autograd graph.
        predicted = outputs.detach().argmax(dim=1)
        correct = predicted.eq(labels).sum().item()
        batch_size = labels.size(0)
        train_acc.update(correct * 100 / batch_size, batch_size)
        train_loss.update(loss.item(), batch_size)

        if batch_idx % args.log_interval == 0:
            wandb.log({'train_acc': train_acc.val, 'train_loss': train_loss.val})

        # Backward and optimize (network and, if present, the agent).
        optimizer.zero_grad()
        if agent is not None:
            agent_optimizer.zero_grad()

        loss.backward()

        optimizer.step()
        if agent is not None:
            agent_optimizer.step()

def validate_epoch(val_loader, net, agent=None):
    """Evaluate ``net`` on ``val_loader`` and log mean accuracy to wandb as ``val_acc``.

    When ``agent`` is given, a policy is sampled exactly as in
    ``train_epoch``: the agent output is reshaped to ``(batch, -1, 2)`` and
    passed through ``gumbel_softmax``, and column 1 is used as the policy.
    The policy is then passed to ``net.forward``, so the network being
    evaluated is the one being trained. Without an agent,
    ``net.forward(images)`` is called with no policy.

    Args:
        val_loader: iterable yielding ``(images, labels)`` batches.
        net: classifier; left in eval mode on return.
        agent: optional SpotTune policy network. Its train/eval mode is
            restored to what it was before the call.
    """
    net.eval()
    agent_was_training = agent.training if agent is not None else False
    if agent is not None:
        agent.eval()

    val_acc = AverageMeter()

    # Inference only: skipping autograd saves memory and time.
    with torch.no_grad():
        for images, labels in tqdm(val_loader):
            if torch.cuda.is_available():
                images = images.cuda(non_blocking=True)
                labels = labels.cuda(non_blocking=True)

            if agent is not None:
                probs = agent(images)
                action = gumbel_softmax(probs.view(probs.size(0), -1, 2))
                outputs = net.forward(images, action[:, :, 1])
            else:
                outputs = net.forward(images)

            predicted = outputs.argmax(dim=1)
            correct = predicted.eq(labels).sum().item()
            batch_size = labels.size(0)
            val_acc.update(correct * 100 / batch_size, batch_size)

    if agent is not None:
        # Put the agent back in the mode it was in, so a subsequent training
        # step is unaffected by this evaluation pass.
        agent.train(agent_was_training)

    wandb.log({'val_acc': val_acc.avg})

def setup_network(net):
    """Freeze the original conv weights of ``net``, keeping the first one trainable.

    Walks ``net.named_modules()`` in order and considers every ``nn.Conv2d``
    whose qualified name does not contain ``'parallel_blocks'``. The first
    such conv is left untouched. Every later one gets
    ``weight.requires_grad = False``. Convs inside ``parallel_blocks``,
    biases and non-conv layers are not modified.

    Afterwards prints the names of the trainable and frozen ``Conv2d`` layers.
    """
    original_convs = (
        module
        for name, module in net.named_modules()
        if isinstance(module, nn.Conv2d) and 'parallel_blocks' not in name
    )
    # Consume (and thus skip) the first original conv so it stays trainable.
    next(original_convs, None)
    for module in original_convs:
        module.weight.requires_grad = False

    # Report which conv layers ended up trainable vs. frozen.
    trainable_names, frozen_names = [], []
    for name, module in net.named_modules():
        if not isinstance(module, nn.Conv2d):
            continue
        bucket = trainable_names if module.weight.requires_grad else frozen_names
        bucket.append(name)

    print(f"Finetuning ({len(trainable_names)}) conv layers:")
    print(trainable_names)

    print(f"Freezing ({len(frozen_names)}) conv layers:")
    print(frozen_names)


### SETTINGS ###

class ConfigObject:
    """Minimal attribute-style container for run settings.

    Every keyword argument passed to the constructor becomes an attribute,
    e.g. ``ConfigObject(lr=0.1).lr == 0.1``.
    """

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

    def to_dict(self):
        """Return the stored settings as a new ``dict``.

        Only instance attributes (the constructor keywords plus anything set
        later) are included, excluding dunder names. Methods such as
        ``to_dict`` itself are not part of the result, so the dict can be
        passed directly as a serializable config (e.g. to ``wandb.init``).
        """
        return {
            key: value
            for key, value in vars(self).items()
            if not key.startswith("__")
        }

    def __repr__(self):
        fields = ", ".join(f"{key}={value!r}" for key, value in self.to_dict().items())
        return f"{type(self).__name__}({fields})"

# Run configuration for this experiment -- edit these values to change the run.
args = ConfigObject(
    initialization='pretrained_models/resnet26_pretrained.t7',  # passed to get_model as `dataset`
    n_epochs=110,             # epochs, also the cosine-annealing period (T_max)
    lr=0.1,                   # SGD learning rate for the network
    lr_agent=0.01,            # SGD learning rate for the SpotTune agent
    batch_size_train=128,
    batch_size_test=256,
    spottune=False,           # True -> also build and train the policy agent
    train_fraction=0.05,      # forwarded to get_visda_dataloaders
    log_interval=10,          # log training metrics every N batches
    gpu_idx=0,                # intended CUDA device index
)

# Start the wandb run and attach the full configuration to it.
wandb.init(project='targeted-generalization', config=args.to_dict())

n_classes = 12  # number of VisDA classes

use_cuda = torch.cuda.is_available()
if use_cuda:
    # Make args.gpu_idx the current device so every bare `.cuda()` call
    # (here and in the train/validate loops) targets it.
    torch.cuda.set_device(args.gpu_idx)

net = get_model("resnet26", n_classes, dataset=args.initialization)
setup_network(net)
if use_cuda:
    net = net.cuda()

train_loader, val_loader, test_loader = get_visda_dataloaders(
    train_dir='/export/r32/data/visda17/train',
    val_dir='/export/r32/data/visda17/test',
    test_dir='/export/r32/data/visda17/test',
    batch_size_train=args.batch_size_train,
    batch_size_test=args.batch_size_test,
    train_fraction=args.train_fraction
)

# Only parameters left trainable by setup_network are optimized (one param group).
optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()),
                      lr=args.lr,
                      momentum=0.9,
                      weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.n_epochs)

if args.spottune:
    # Two outputs per entry of net.layer_config total; the train/val loops
    # reshape the agent output to (batch, -1, 2).
    agent = agent_net.resnet(sum(net.layer_config) * 2)
    if use_cuda:
        # Inputs are moved to the GPU in the train/val loops, so the agent
        # must live there too (otherwise agent(images) raises a device error).
        agent = agent.cuda()
    agent_optimizer = optim.SGD(agent.parameters(),
                                lr=args.lr_agent,
                                momentum=0.9,
                                weight_decay=0.001)
    agent_scheduler = optim.lr_scheduler.CosineAnnealingLR(
        agent_optimizer,
        T_max=args.n_epochs)
else:
    agent = None
    agent_optimizer = None
    agent_scheduler = None

# Accuracy before any training, as a baseline.
validate_epoch(val_loader, net, agent)
for epoch in trange(args.n_epochs):
    train_epoch(train_loader, net, agent)
    validate_epoch(val_loader, net, agent)

    # get_last_lr() returns one value per param group; each optimizer here
    # has a single group, so log it as a scalar.
    wandb.log({'net_lr': scheduler.get_last_lr()[0]})
    scheduler.step()

    if args.spottune:
        wandb.log({'agent_lr': agent_scheduler.get_last_lr()[0]})
        agent_scheduler.step()