In [1]:
import torch
from pathlib import Path
from torchvision import datasets

from einops import rearrange

In [2]:
import sys

# Put the local TRAK fork ahead of any installed `trak` package. Any existing
# copy is removed first, so re-running this cell leaves exactly one entry.
TRAK_FORK_DIR = '/rcfs/projects/task0_pmml/TRAKfork/trak'
sys.path[:] = [TRAK_FORK_DIR] + [p for p in sys.path if p != TRAK_FORK_DIR]

In [4]:
# Only one trained ResNet-18 checkpoint is available, so this list has a single
# entry. The featurize/score loops below iterate over it with enumerate() and
# give each checkpoint its own model_id.
ckpts = [torch.load(ckpt_path) for ckpt_path in ('/rcfs/projects/task0_pmml/MODELS/resnet18.pt',)]

In [5]:
from trak import TRAKer
from trak.modelout_functions import trNTKModelOutput 

In [6]:
#saver = savers.Mmapsaver()

In [7]:
import torch
import torch.nn as nn
import os

# Names exported by `from <module> import *` if this model code is moved into a module.
__all__ = ["ResNet", "resnet18", "resnet34", "resnet50"]


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d``.

    Padding equals ``dilation``, so with ``stride=1`` the output keeps the
    input's spatial size.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.
        groups: number of blocked connections (grouped convolution).
        dilation: kernel dilation (also used as the padding).

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)


def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free pointwise (1x1) ``nn.Conv2d`` with the given stride.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    pointwise = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise


class BasicBlock(nn.Module):
    """Two-layer 3x3 residual block (the ResNet-18/34 building block).

    Main branch: conv1 -> bn1 -> relu -> conv2 -> bn2. The shortcut is the
    input itself, or ``downsample(x)`` when a downsample module is given. The
    two are summed and passed through a final ReLU. ``expansion = 1``: the
    block outputs ``planes`` channels.

    Raises:
        ValueError: if ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: if ``dilation > 1``.
    """

    expansion = 1

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None,
    ):
        super(BasicBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer

        # The stride is applied by conv1 here and, when provided, by `downsample`
        # on the shortcut, so both paths shrink the spatial size together.
        # Submodules are registered in the same order as before, which keeps
        # the state_dict / parameter ordering unchanged.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = make_norm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = make_norm(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)


class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block (the ResNet-50 building block).

    The inner width is ``int(planes * base_width / 64) * groups``. The middle
    3x3 convolution carries the stride, groups and dilation. The last 1x1
    convolution expands to ``planes * expansion`` (``expansion = 4``)
    channels. The shortcut is the input, or ``downsample(x)`` when given.
    The sum goes through a final ReLU.
    """

    expansion = 4

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None,
    ):
        super(Bottleneck, self).__init__()
        make_norm = nn.BatchNorm2d if norm_layer is None else norm_layer
        inner = int(planes * (base_width / 64.0)) * groups
        outer = planes * self.expansion

        # conv2 (and `downsample`, if provided) are the layers that apply the
        # stride. Registration order is kept as before so the state_dict /
        # parameter ordering is unchanged.
        self.conv1 = conv1x1(inplanes, inner)
        self.bn1 = make_norm(inner)
        self.conv2 = conv3x3(inner, inner, stride, groups, dilation)
        self.bn2 = make_norm(inner)
        self.conv3 = conv1x1(inner, outer)
        self.bn3 = make_norm(outer)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = x
        # The first two conv/BN pairs are each followed by a ReLU.
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(bn(conv(out)))
        out = self.bn3(self.conv3(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)


class ResNet(nn.Module):
    """ResNet with a CIFAR-10 style stem.

    Follows the torchvision ResNet layout, except that the first convolution
    uses ``kernel_size=3, stride=1, padding=1`` instead of 7/2/3. The stem
    max-pool is kept.

    Besides ``forward``, the class exposes truncated forward passes that
    return intermediate activations (``forward_conv1``, ``forward_bn1``,
    ``forward_layer1`` ... ``forward_layer4``, ``forward_flat``). It also
    exposes ``forward_cam_conv1``, which returns ``(logits, conv1_output)``.
    All of them share the private helpers ``_stem``, ``_run_stages`` and
    ``_pool_flatten``, so the layer order is defined in one place.

    Args:
        block: residual block class (e.g. ``BasicBlock`` or ``Bottleneck``);
            must provide an integer ``expansion`` attribute.
        layers: four block counts, one per stage.
        num_classes: output size of the final ``fc`` layer.
        zero_init_residual: if True, zero-initialise the last BN weight of
            every ``BasicBlock``/``Bottleneck`` residual branch.
        groups: convolution groups forwarded to each block.
        width_per_group: base width forwarded to each block.
        replace_stride_with_dilation: three booleans for stages 2-4; a True
            entry replaces that stage's stride-2 with dilation.
        norm_layer: normalisation layer factory (default ``nn.BatchNorm2d``).

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have 3 entries.
    """

    def __init__(
        self,
        block,
        layers,
        num_classes=10,
        zero_init_residual=False,
        groups=1,
        width_per_group=64,
        replace_stride_with_dilation=None,
        norm_layer=None,
    ):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError(
                "replace_stride_with_dilation should be None "
                "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
            )
        self.groups = groups
        self.base_width = width_per_group

        # CIFAR10: kernel_size 7 -> 3, stride 2 -> 1, padding 3->1
        self.conv1 = nn.Conv2d(
            3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False
        )
        # END

        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(
            block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
        )
        self.layer3 = self._make_layer(
            block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
        )
        self.layer4 = self._make_layer(
            block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage made of ``blocks`` residual blocks.

        The first block gets ``stride`` and the previous dilation. When
        ``dilate`` is True it instead gets stride 1, and ``self.dilation`` is
        multiplied by ``stride``. The first block also gets a 1x1-conv + norm
        ``downsample`` shortcut whenever the stride or the channel count
        changes. The remaining blocks use the current ``self.dilation``.
        ``self.inplanes`` is updated for the next stage.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                downsample,
                self.groups,
                self.base_width,
                previous_dilation,
                norm_layer,
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )

        return nn.Sequential(*layers)

    # -- shared pieces of the forward passes -------------------------------

    def _stem(self, x):
        """Apply conv1 -> bn1 -> relu -> maxpool."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        return self.maxpool(x)

    def _run_stages(self, x, num_stages=4):
        """Apply the first ``num_stages`` of layer1..layer4, in order."""
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4)[:num_stages]:
            x = stage(x)
        return x

    def _pool_flatten(self, x):
        """Global-average-pool, then flatten to ``(batch, channels)``."""
        x = self.avgpool(x)
        return x.reshape(x.size(0), -1)

    # -- public forward passes ----------------------------------------------

    def forward(self, x):
        """Return the class logits for ``x``."""
        return self.fc(self._pool_flatten(self._run_stages(self._stem(x))))

    def forward_cam_conv1(self, x):
        """Return ``(logits, conv1_output)``, where ``conv1_output`` is the raw (pre-BN) conv1 activation."""
        conv1_out = self.conv1(x)
        # bn1 produces a new tensor, so the in-place ReLU leaves conv1_out untouched.
        x = self.maxpool(self.relu(self.bn1(conv1_out)))
        logits = self.fc(self._pool_flatten(self._run_stages(x)))
        return logits, conv1_out

    def forward_conv1(self, x):
        """Return the conv1 output."""
        return self.conv1(x)

    def forward_bn1(self, x):
        """Return the bn1 output (before ReLU and max-pool)."""
        return self.bn1(self.conv1(x))

    def forward_layer1(self, x):
        """Return the activations after layer1."""
        return self._run_stages(self._stem(x), 1)

    def forward_layer2(self, x):
        """Return the activations after layer2."""
        return self._run_stages(self._stem(x), 2)

    def forward_layer3(self, x):
        """Return the activations after layer3."""
        return self._run_stages(self._stem(x), 3)

    def forward_layer4(self, x):
        """Return the activations after layer4 (before pooling)."""
        return self._run_stages(self._stem(x), 4)

    def forward_flat(self, x):
        """Return the pooled, flattened features that feed ``fc``."""
        return self._pool_flatten(self._run_stages(self._stem(x)))
        


def _resnet(arch, block, layers, pretrained, progress, device, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        script_dir = os.path.dirname(__file__)
        state_dict = torch.load(
            script_dir + "/state_dicts/" + arch + ".pt", map_location=device
        )
        model.load_state_dict(state_dict)
    return model


def resnet18(pretrained=False, progress=True, device="cpu", **kwargs):
    """Construct the CIFAR-style ResNet-18 (``BasicBlock``, stages ``[2, 2, 2, 2]``).

    Args:
        pretrained (bool): If True, ``_resnet`` loads weights from the local
            file ``state_dicts/resnet18.pt``. Nothing is downloaded.
        progress (bool): Passed to ``_resnet``, which does not use it.
        device: ``map_location`` used when loading the pretrained weights.
        **kwargs: Forwarded to ``ResNet`` (e.g. ``num_classes``).
    """
    return _resnet(
        "resnet18", BasicBlock, [2, 2, 2, 2], pretrained, progress, device, **kwargs
    )

def resnet34(pretrained=False, progress=True, device="cpu", **kwargs):
    """Construct the CIFAR-style ResNet-34 (``BasicBlock``, stages ``[3, 4, 6, 3]``).

    Args:
        pretrained (bool): If True, ``_resnet`` loads weights from the local
            file ``state_dicts/resnet34.pt``. Nothing is downloaded.
        progress (bool): Passed to ``_resnet``, which does not use it.
        device: ``map_location`` used when loading the pretrained weights.
        **kwargs: Forwarded to ``ResNet`` (e.g. ``num_classes``).
    """
    return _resnet(
        "resnet34", BasicBlock, [3, 4, 6, 3], pretrained, progress, device, **kwargs
    )


def resnet50(pretrained=False, progress=True, device="cpu", **kwargs):
    """Construct the CIFAR-style ResNet-50 (``Bottleneck``, stages ``[3, 4, 6, 3]``).

    Args:
        pretrained (bool): If True, ``_resnet`` loads weights from the local
            file ``state_dicts/resnet50.pt``. Nothing is downloaded.
        progress (bool): Passed to ``_resnet``, which does not use it.
        device: ``map_location`` used when loading the pretrained weights.
        **kwargs: Forwarded to ``ResNet`` (e.g. ``num_classes``).
    """
    return _resnet(
        "resnet50", Bottleneck, [3, 4, 6, 3], pretrained, progress, device, **kwargs
    )

In [8]:
MODELNAME = 'ResNet18'
# pretrained=False, so `device` has no effect here (it is only used as
# map_location when loading weights). The weights are presumably replaced
# later by traker.load_checkpoint(ckpt).
model = resnet18(device='cuda')
model = model.to(memory_format=torch.channels_last)
model = model.cuda().eval()

In [9]:
import trak

In [19]:
# TRAKer for the "combined" CIFAR-10 set that is featurized below, using the
# trNTK task, a 512 * 20 projection dimension and full precision.
traker = TRAKer(
    model=model,
    task='trNTK',
    save_dir='/rcfs/projects/task0_pmml/proj_trNTK/test_blind/',
    train_set_size=60_000,  # should equal len(combined_loader.dataset)
    num_classes=10,
    proj_dim=512 * 20,
    proj_max_batch_size=16,
    use_half_precision=False,
)

                             Report any issues at https://github.com/MadryLab/trak/issues
INFO:STORE:Existing model IDs in /rcfs/projects/task0_pmml/proj_trNTK/test_blind: [0]
INFO:STORE:No model IDs in /rcfs/projects/task0_pmml/proj_trNTK/test_blind have been finalized.
INFO:STORE:No existing TRAK scores in /rcfs/projects/task0_pmml/proj_trNTK/test_blind.


In [20]:
from einops import rearrange

In [21]:
from utils2 import process_Cifar10
# process_Cifar10 (local utils2 module) returns three datasets; they are wrapped
# by the DataLoaders in the next cell, and `combined` is the one featurized.
(train, test, combined) = process_Cifar10('./')

In [22]:
# Root holding the already-downloaded CIFAR-10 files (download is disabled).
CIFAR10_ROOT = '/people/enge625/NOTEBOOKS/'

train_data = datasets.CIFAR10(root=CIFAR10_ROOT, train=True, download=False)
test_data = datasets.CIFAR10(root=CIFAR10_ROOT, train=False, download=False)

# Per-channel normalisation statistics; they are broadcast over the last
# (channel) axis of the raw images.
mean = torch.tensor([0.4914, 0.4822, 0.4465])
std = torch.tensor([0.2471, 0.2435, 0.2616])


def _prepare_cifar10_split(dataset, mean, std, device='cuda'):
    """Turn one torchvision CIFAR10 split into model tensors and plotting arrays.

    Args:
        dataset: a ``datasets.CIFAR10`` split (``.data`` and ``.targets`` are used).
        mean: 3-element tensor broadcast over the last (channel) axis.
        std: 3-element tensor broadcast over the last (channel) axis.
        device: device the model-ready tensors are moved to.

    Returns:
        ``(x, y, x_plot, y_plot)``:
            - ``x``: images divided by 255, normalised with ``mean``/``std``,
              moved to ``device`` and permuted channels-last -> channels-first.
            - ``y``: labels as float32 on ``device``.
            - ``x_plot``, ``y_plot``: the raw images and labels as NumPy arrays.
    """
    images = torch.tensor(dataset.data)
    labels = torch.tensor(dataset.targets)
    # NumPy views of the raw copies, used for plotting. `images` and `labels`
    # are never modified below, so these views keep the raw values. This avoids
    # converting the raw data a second time.
    x_plot = images.numpy()
    y_plot = labels.numpy()

    x = images / 255
    x -= mean[None, None, None, :]
    x /= std[None, None, None, :]
    x = rearrange(x.to(device), 'b h w c -> b c h w')
    y = labels.float().to(device)
    return x, y, x_plot, y_plot


train_x, train_y, train_x_plot, train_y_plot = _prepare_cifar10_split(train_data, mean, std)
test_x, test_y, test_x_plot, test_y_plot = _prepare_cifar10_split(test_data, mean, std)

# NOTE(review): the loaders below wrap `train`, `test` and `combined` from
# process_Cifar10, not train_x / test_x. The GPU copies made above only occupy
# device memory unless some other cell uses them.

from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset

train_loader = DataLoader(train, batch_size=64, shuffle=False)
test_loader = DataLoader(test, batch_size=64, shuffle=False)
test_loader2 = DataLoader(test, batch_size=150, shuffle=False)
combined_loader = DataLoader(combined, batch_size=64, shuffle=False)

In [23]:
# Sanity check: should equal the train_set_size given to TRAKer (60_000).
len(combined_loader.dataset)

60000

In [24]:
import trak

In [25]:
from tqdm import tqdm

for model_id, ckpt in enumerate(ckpts):
    # TRAKer loads the provided checkpoint and also associates
    # the provided (unique) model_id with the checkpoint.
    traker.load_checkpoint(ckpt, model_id=model_id)

    for batch in tqdm(combined_loader):
        # Move every tensor of the batch to the GPU, where the model lives.
        batch = [t.cuda() for t in batch]
        # Labels must be integer class indices. `.long()` does the same dtype
        # conversion as `torch.tensor(tensor, dtype=torch.long)`, but without
        # copy-constructing from an existing tensor, which emits a UserWarning
        # (its echo is visible in this cell's output).
        batch[1] = batch[1].long()
        # TRAKer computes features corresponding to the batch of examples,
        # using the checkpoint loaded above.
        traker.featurize(batch=batch, num_samples=batch[0].shape[0])

# Tells TRAKer that we've given it all the information, at which point
# TRAKer does some post-processing to get ready for the next step
# (scoring target examples).
# NOTE(review): intentionally left disabled; the next cell reads the per-class
# `grads_{c}` arrays directly from traker.saver.current_store instead.
#traker.finalize_features()

  batch[1] = torch.tensor(batch[1],dtype=torch.long)
  6%|████▊                                                                           | 56/938 [04:30<1:10:55,  4.82s/it]


KeyboardInterrupt: 

In [15]:
import numpy as np
for c in range(traker.num_classes):
    # Per-class gradient array kept in the saver's in-memory store under `grads_<c>`.
    class_grads = traker.saver.current_store[f'grads_{c}']
    np.save(
        f'/rcfs/projects/task0_pmml/proj_trNTK/resnet18_trNTK/largegrads_{c}.npy',
        class_grads,
    )

# run above

In [None]:
import numpy as np

In [None]:
# Gradient features written by TRAK's saver for model_id 0.
grad_savepath = traker.save_dir / '0' / 'grads.mmap'

In [None]:
# np.load only works here if the .mmap file carries an .npy header (as written
# by np.lib.format.open_memmap) — assumption, confirm for this TRAK fork.
A = torch.from_numpy(
    np.load(grad_savepath)
).to(device='cuda')

In [None]:
# Shape of the gradient-feature matrix (rows: examples, presumably — confirm).
A.size()

In [None]:
# Gram matrix of the gradient features (the projected NTK). For N rows it is
# N x N, so it can be large both on the GPU and on the host.
proj_NTK = (A @ A.T).cpu().numpy()

In [None]:
# Store the projected NTK next to model 0's TRAK features.
proj_NTK_savepath = traker.save_dir / '0' / 'proj_pntk.npy'
np.save(proj_NTK_savepath, proj_NTK)

In [None]:
# Expected to be square (N x N).
np.shape(proj_NTK)

# This should save to store:


# Kept below for posterity: my old, incorrect way of computing the attributions. The approach above instead loads the features from TRAK's saver store, computes the NTK by multiplying the feature matrix by its own transpose, and ends up with an object like all the other NTK objects I analyze.

In [None]:
# Deliberate stop so that "Run All" does not execute the legacy cells below.
# RuntimeError subclasses Exception, so existing `except Exception` handling still applies.
raise RuntimeError(
    "Stopping here: the cells below are the old, incorrect attribution method, kept for reference only."
)

In [None]:
# Legacy scoring pass over the test split (superseded by the NTK computed above).
for model_id, checkpoint in enumerate(ckpts):
    traker.start_scoring_checkpoint(checkpoint, model_id=model_id, num_targets=len(test))
    for batch in test_loader:
        # Same handling as the featurize loop: the model is on CUDA, so move
        # the batch there. Labels become integer indices via `.long()`, which
        # avoids copy-constructing with torch.tensor(tensor, ...).
        batch = [t.cuda() for t in batch]
        batch[1] = batch[1].long()
        traker.score(batch=batch, num_samples=batch[0].shape[0])
        


In [None]:
# Finish the scoring run started above (test targets) and keep the returned scores.
scores = traker.finalize_scores()

In [None]:
# Legacy scoring pass over the train split (superseded by the NTK computed above).
for model_id, checkpoint in enumerate(ckpts):
    traker.start_scoring_checkpoint(checkpoint, model_id=model_id, num_targets=len(train))
    for batch in train_loader:
        # Same handling as the featurize loop: the model is on CUDA, so move
        # the batch there. Labels become integer indices via `.long()`, which
        # avoids copy-constructing with torch.tensor(tensor, ...).
        batch = [t.cuda() for t in batch]
        batch[1] = batch[1].long()
        traker.score(batch=batch, num_samples=batch[0].shape[0])

scores_train = traker.finalize_scores()

In [None]:
# Alias the test-target scores under an explicit name.
scores_test = scores

In [None]:
# Inspect the shape of the test-target scores.
scores_test.shape

In [None]:
# Sum the scores over axis 0. The author's intent: this should reproduce
# "line 19", log(p / (1 - p)), i.e. the logit of the correct class, and the
# author notes this cell does compute that. ("Line 19" presumably refers to
# the TRAK paper's algorithm listing — confirm.)
g = torch.sum(scores_test, dim=0)

In [None]:
# Inspect the shape of the summed scores.
g.shape

In [None]:
# torch.save does not create missing directories, so ensure the target folder exists first.
scores_out_path = Path('./trak_results/scores/CIFAR10_test_logits_of_correct_class.pt')
scores_out_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(g, scores_out_path)