In [1]:
from __future__ import print_function
import argparse

import torch
import torch.nn as nn
import torch.optim as optim
import datasets
from utils.misc import *
from utils.test_helpers import *
from utils.prepare_dataset import *

# ----------------------------------

import copy
import time
import pandas as pd

import random
import numpy as np
import torch.backends.cudnn as cudnn

from discrepancy import *
from offline import *
from utils.trick_helpers import *
from utils.contrastive import *

from online import FeatureQueue




In [2]:
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='civil')
parser.add_argument('--dataroot', default='.')
parser.add_argument('--shared', default=None)
########################################################################
parser.add_argument('--depth', default=26, type=int)
parser.add_argument('--width', default=1, type=int)
parser.add_argument('--batch_size', default=128, type=int)
parser.add_argument('--batch_size_align', default=512, type=int)
parser.add_argument('--queue_size', default=256, type=int)
parser.add_argument('--group_norm', default=0, type=int)
parser.add_argument('--workers', default=0, type=int)
parser.add_argument('--num_sample', default=1000000, type=int)
########################################################################
parser.add_argument('--lr', default=0.001, type=float)
parser.add_argument('--nepoch', default=100, type=int, help='maximum number of epoch for ttt')
parser.add_argument('--bnepoch', default=2, type=int, help='first few epochs to update bn stat')
parser.add_argument('--delayepoch', default=0, type=int)
parser.add_argument('--stopepoch', default=25, type=int)
########################################################################
parser.add_argument('--outf', default='.')
########################################################################
parser.add_argument('--level', default=5, type=int)
parser.add_argument('--corruption', default='snow')
parser.add_argument('--resume', default=None, help='directory of pretrained model')
parser.add_argument('--ckpt', default=None, type=int)
parser.add_argument('--fix_ssh', action='store_true')
########################################################################
parser.add_argument('--method', default='ssl', choices=['ssl', 'align', 'both'])
parser.add_argument('--divergence', default='all', choices=['all', 'coral', 'mmd'])
parser.add_argument('--scale_ext', default=0.5, type=float, help='scale of align loss on ext')
parser.add_argument('--scale_ssh', default=0.2, type=float, help='scale of align loss on ssh')
########################################################################
parser.add_argument('--ssl', default='contrastive', help='self-supervised task')
parser.add_argument('--temperature', default=0.5, type=float)
########################################################################
parser.add_argument('--align_ext', action='store_true')
parser.add_argument('--align_ssh', action='store_true')
########################################################################
parser.add_argument('--model', default='bert', help='bert')
parser.add_argument('--save_every', default=100, type=int)
########################################################################
parser.add_argument('--tsne', action='store_true')
########################################################################
parser.add_argument('--seed', default=0, type=int)


_StoreAction(option_strings=['--seed'], dest='seed', nargs=None, const=None, default=0, type=<class 'int'>, choices=None, help=None, metavar=None)

In [3]:
class args:
    dataset='civil'
    dataroot='.'
    shared=None
    depth=26
    width=1
    batch_size=128
    batch_size_align=512
    queue_size=256
    group_norm=0
    workers=0
    num_sample=1000
    lr=0.001
    nepoch=100
    bnepoch=2
    delayepoch=0
    stopepoch=25
    outf='.'
    level=5
    corruption='snow'
    resume=None
    ckpt=None
    fix_ssh=False
    method='ssl'
    divergence='all'
    scale_ext=0.5
    scale_ssh=0.2
    ssl='contrastive'
    temperature=0.5
    align_ext=False
    align_ssh=False
    model='bert'
    save_every=100
    tsne=False
    seed=0

In [4]:
net, ext, head, ssh, classifier = build_bert(args, "bert")

In [5]:
net

ExtractorHead(
  (ext): BertFeaturizer(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=

In [6]:
device = torch.device("cpu")

In [7]:
modelC = getattr(datasets, "civil")
# teloader, tv_loaders = modelC.getDataLoaders(args, device=device)

In [9]:
args.batch_size = min(args.batch_size, args.num_sample)
args.batch_size_align = min(args.batch_size_align, args.num_sample)

args_align = copy.deepcopy(args)
args_align.ssl = None
args_align.batch_size = args.batch_size_align

In [10]:
trloader, tv_loaders = modelC.getDataLoaders(args, device=device)
trloader_extra, teloader = tv_loaders['val'], tv_loaders['test']
model = modelC(args, weights=None).to(device)

<wilds.datasets.civilcomments_dataset.CivilCommentsDataset object at 0x121e77640>


In [16]:
offlineloader = modelC.getOfflineDataLoaders(args_align, device=device)

<wilds.datasets.civilcomments_dataset.CivilCommentsDataset object at 0x121db6c70>


In [15]:
def offline(trloader, ext, scale):
    ext.eval()

    mu_src = None
    cov_src = None

    coral_stack = []
    mmd_stack = []
    feat_stack = []

    with torch.no_grad():
        for batch_idx, (inputs, labels, meta) in enumerate(trloader):
            print(inputs.shape)
            feat = ext(inputs.cuda())
            cov = covariance(feat)
            mu = feat.mean(dim=0)

            if cov_src is None:
                cov_src = cov
                mu_src = mu
            else:
                loss_coral = coral(cov_src, cov)
                loss_mmd = linear_mmd(mu_src, mu)
                coral_stack.append(loss_coral.item())
                mmd_stack.append(loss_mmd.item())
                feat_stack.append(feat)

    print("Source loss_mean: mu = {:.4f}, std = {:.4f}".format(scale, scale / statistics.mean(mmd_stack) * statistics.stdev(mmd_stack)))
    print("Source loss_coral: mu = {:.4f}, std = {:.4f}".format(scale, scale / statistics.mean(coral_stack) * statistics.stdev(coral_stack)))

    feat_all = torch.cat(feat_stack)
    feat_cov = covariance(feat_all)
    feat_mean = feat_all.mean(dim=0)
    return feat_cov, statistics.mean(coral_stack), feat_mean, statistics.mean(mmd_stack)

MMD_SCALE_FACTOR = 0.5
#if args.align_ext:
args_align.scale = args.scale_ext
cov_src_ext, coral_src_ext, mu_src_ext, mmd_src_ext = offline(offlineloader, ext, args.scale_ext)
scale_coral_ext = args.scale_ext / coral_src_ext
scale_mmd_ext = args.scale_ext / mmd_src_ext * MMD_SCALE_FACTOR

# construct queue
if args.queue_size > args.batch_size_align:
    queue_ext = FeatureQueue(dim=mu_src_ext.shape[0], length=args.queue_size-args.batch_size_align)

if args.align_ssh:
    args_align.scale = args.scale_ssh
    from models.SSHead import ExtractorHead
    cov_src_ssh, coral_src_ssh, mu_src_ssh, mmd_src_ssh = offline(offlineloader, ExtractorHead(ext, head).cuda(), args.scale_ssh)
    scale_align_ssh = args.scale_ssh / coral_src_ssh
    scale_mmd_ssh = args.scale_ssh / mmd_src_ssh * MMD_SCALE_FACTOR

    if args.queue_size > args.batch_size_align:
        queue_ssh = FeatureQueue(dim=mu_src_ssh.shape[0], length=args.queue_size-args.batch_size_align)


torch.Size([512, 300, 3])
torch.Size([512, 300, 3])


KeyboardInterrupt: 

In [17]:
ext

BertFeaturizer(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
     

In [None]:
import numpy as np

In [None]:
def test(dataloader, model, sslabel=None):
    criterion = nn.CrossEntropyLoss(reduction='none').cuda()
    model.eval()
    correct = []
    losses = []
    for batch_idx, (inputs, labels, meta) in enumerate(dataloader):
        #input: torch.Size([128, 300, 3])
        print(inputs.shape)
#         if sslabel is not None:
#             inputs, labels = rotate_batch(inputs, sslabel)
        inputs, labels = inputs.cuda(), labels.cuda()
        with torch.no_grad():
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            losses.append(loss.cuda())
            _, predicted = outputs.max(1)
            correct.append(predicted.eq(labels).cuda())
    correct = torch.cat(correct).numpy()
    losses = torch.cat(losses).numpy()
    model.train()
    return 1-correct.mean(), correct, losses
err_cls = test(teloader, net)[0]