In [None]:
%%capture

import os
import glob
import jax
import torch
import optax
import wandb 
import shutil
import functools
import haiku as hk
import numpy as np
import pandas as pd
from PIL import Image
import jax.numpy as jnp
import plotly.express as px
from functools import partial
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import transforms, models
from pydash import flatten_deep as flatten, unzip
from torch.utils.data import TensorDataset, DataLoader, SubsetRandomSampler

In [None]:
# Define dataset of unmasked faces

class UnmaskedFaces():
    """Loads the LFW face images from ``LFW/<person>/*.jpg`` and serves one data split.

    Every image is opened, resized to 150x150, converted to a tensor and passed through
    the extra ``transform`` callables when the dataset is built (i.e. eagerly, once).

    Args:
        mode: ``"train"``, ``"val"`` or anything else (treated as the test split).
        transform: extra callables appended to the image pipeline, applied in order.
        splits: ``[train_fraction, val_fraction]``; the remainder is the test split.
    """

    def __init__(self, mode, transform=[lambda x: x], splits=[0.8, 0.1]):
        super().__init__()
        np.random.seed(42)

        self.mode = mode
        self.splits = splits

        self.transform = transforms.Compose([
            Image.open,
            lambda x: x.resize((150, 150)),
            transforms.ToTensor(),
            *transform
        ])

        # Person directory name -> integer class id, in os.listdir order.
        self.label_mapping = {name: idx for idx, name in enumerate(os.listdir("LFW"))}

    def get_name(self, path):
        """Return the person name encoded in an image path.

        The file name is split on ``_`` and every piece containing a digit (the running
        image number and extension) is dropped, e.g.
        ``LFW/John_Doe/John_Doe_0001.jpg`` -> ``John_Doe``.
        """
        pieces = os.path.basename(path).split("_")
        return "_".join(p for p in pieces if not any(ch.isdigit() for ch in p))

    def read_LFW(self):
        """Return one list of ``.jpg`` paths per person directory under ``LFW``."""
        return [glob.glob(f"LFW/{person}/*.jpg") for person in os.listdir("LFW")]

    def preprocess(self, sample_list):
        """Load all images in ``sample_list`` and one-hot encode their labels.

        Returns:
            ``(images, labels)``: a float tensor of stacked transformed images and a
            one-hot label tensor with one column per entry of ``os.listdir("LFW")``.
        """
        names = [self.get_name(path) for path in sample_list]
        images = torch.stack([self.transform(path) for path in sample_list]).float()
        labels = torch.tensor([self.label_mapping[name] for name in names])
        return images, F.one_hot(labels, len(os.listdir("LFW")))

    def _split_indices(self, n_samples):
        """Return the shuffled sample indices that belong to ``self.mode``.

        A private, fixed-seed generator is used so that separately constructed
        train/val/test instances always see the same permutation and therefore
        receive disjoint index sets, regardless of other uses of ``np.random``.
        """
        # permutation() returns a shuffled copy; the result must be kept.
        indices = np.random.RandomState(42).permutation(n_samples)
        n_train = int(n_samples * self.splits[0])
        n_val = int(n_samples * self.splits[1])

        if self.mode == "train":
            return indices[:n_train]
        if self.mode == "val":
            return indices[n_train:n_train + n_val]
        return indices[n_train + n_val:]

    def generate_dataloader(self):
        """Build the dataset and return a ``batch_size -> DataLoader`` factory.

        The returned loaders sample randomly (without replacement) from the indices of
        this instance's split only.
        """
        data = flatten(self.read_LFW())
        dataset = TensorDataset(*self.preprocess(data))
        datapoints = self._split_indices(len(data))

        return lambda b: DataLoader(
            dataset,
            batch_size=b,
            sampler=SubsetRandomSampler(np.array(datapoints))
        )

In [None]:
# Define pseudo-masking functions

def mask_lower_half_of_image(img, value, amount=0.5):
    """Mask a band at the bottom of a ``(channels, height, width)`` image tensor.

    Args:
        img: image tensor; any channel count and non-square sizes are supported.
        value: ``"random_one"`` fills the band with one random colour (one value per
            channel), ``"random_many"`` fills it with independent random values per
            pixel, anything else fills it with zeros.
        amount: controls the band height (see note below).

    Returns:
        The masked image. The two random modes write into ``img`` in place and return
        it; the zero-fill mode returns a new tensor and leaves ``img`` untouched.

    Note:
        The band height differs between modes, which existing callers rely on: the
        random modes mask the bottom ``int(height * (1 - amount))`` rows, while the
        zero-fill mode masks everything below row ``int(height * (1 - amount))``.
        For ``amount=0.5`` on an even height both are the lower half.
    """
    channels, height, width = img.shape
    half = int(height * (1 - amount))
    minus_half = height - half

    if value == "random_one":
        # One random value per channel, broadcast over the whole band.
        colour = torch.rand(channels, 1).unsqueeze(1)
        img[:, minus_half:, :] = colour.expand(channels, half, width)
        return img
    if value == "random_many":
        img[:, minus_half:, :] = torch.rand((channels, half, width))
        return img

    blacked_out = img.clone()
    blacked_out[:, half:, :] = 0
    return blacked_out

def mask_with_prob_color(img, prob, value, block=2):
    """Replace random ``block`` x ``block`` patches of an image with random colours.

    The image is divided into a grid of ``block``-sized cells; each cell is masked
    independently with probability ``prob``.

    Args:
        img: ``(channels, height, width)`` tensor. It is not modified.
        prob: probability in ``[0, 1]`` that a cell is masked.
        value: ``"random_one"`` paints every masked cell with the same random colour
            (one value per channel); any other value draws fresh random values for
            every masked cell.
        block: side length of a grid cell in pixels (default 2).

    Returns:
        A new tensor of the same shape as ``img``. Rows/columns left over when the
        image size is not a multiple of ``block`` are never masked.
    """
    channels, height, width = img.shape
    cell_mask = torch.rand(size=(height // block, width // block)) < prob
    rnd_block = torch.rand(channels, 1).unsqueeze(1).repeat(1, block, block)

    masked = img.clone()
    # nonzero() lists the masked cells in row-major order.
    for row, col in torch.nonzero(cell_mask).tolist():
        top, left = row * block, col * block
        if value == "random_one":
            patch = rnd_block
        else:
            patch = torch.rand(channels, block, block)
        masked[:, top:top + block, left:left + block] = patch
    return masked
    
def mask_with_prob(img, prob, block=2):
    """Black out random ``block`` x ``block`` patches of an image.

    The image is divided into a grid of ``block``-sized cells; each cell is set to
    zero (in all channels) independently with probability ``prob``.

    Args:
        img: ``(channels, height, width)`` tensor. It is not modified.
        prob: probability in ``[0, 1]`` that a cell is blacked out.
        block: side length of a grid cell in pixels (default 2).

    Returns:
        A new tensor of the same shape as ``img``. Rows/columns left over when the
        image size is not a multiple of ``block`` are never masked.
    """
    height, width = img.shape[-2:]
    rows, cols = height // block, width // block
    cell_mask = torch.rand(size=(rows, cols)) < prob

    # masked_fill needs a boolean mask; blow each grid cell up to block x block pixels.
    pixel_mask = torch.zeros(height, width, dtype=torch.bool)
    pixel_mask[:rows * block, :cols * block] = (
        cell_mask.repeat_interleave(block, dim=0).repeat_interleave(block, dim=1)
    )
    # The (height, width) mask broadcasts over the channel dimension.
    return img.masked_fill(pixel_mask.to(img.device), 0)

In [None]:
# Initialize datasets

# Training loader factory (batch_size -> DataLoader). Each image is passed through
# mask_lower_half_of_image(..., "random_many") with the default amount=0.5 while the
# dataset is being built.
# NOTE(review): later cells also reference `unmasked_val_lfw`, `unmasked_test_lfw` and
# `truly_masked`; their definitions are commented out below and must be restored
# before those cells can run.
unmasked_train_lfw = UnmaskedFaces("train", transform=[lambda x: mask_lower_half_of_image(x, "random_many")]).generate_dataloader()
#unmasked_val_lfw = UnmaskedFaces("val").generate_dataloader()
#unmasked_test_lfw = UnmaskedFaces("test").generate_dataloader()

# masked_25_percent_black_lfw = UnmaskedFaces("test", transform=[lambda x: mask_with_prob(x, 0.25)]).generate_dataloader(indices)
# masked_50_percent_black_lfw = UnmaskedFaces("test", transform=[lambda x: mask_with_prob(x, 0.50)]).generate_dataloader(indices)
# masked_75_percent_black_lfw = UnmaskedFaces("test", transform=[lambda x: mask_with_prob(x, 0.75)]).generate_dataloader(indices)
# masked_quarter_black_lfw = UnmaskedFaces("test", transform=[lambda x: mask_lower_half_of_image(x, 0, 0.25)]).generate_dataloader(indices)
# masked_half_black_lfw = UnmaskedFaces("test", transform=[lambda x: mask_lower_half_of_image(x, 0, 0.5)]).generate_dataloader(indices)
# masked_three_quarters_black_lfw = UnmaskedFaces("test", transform=[lambda x: mask_lower_half_of_image(x, 0, 0.75)]).generate_dataloader(indices)

# masked_25_percent_random_one_lfw = UnmaskedFaces("test", transform=[lambda x: mask_with_prob_color(x, 0.25, "random_one")]).generate_dataloader(indices)
# masked_50_percent_random_one_lfw = UnmaskedFaces("test", transform=[lambda x: mask_with_prob_color(x, 0.50, "random_one")]).generate_dataloader(indices)
# masked_75_percent_random_one_lfw = UnmaskedFaces("test", transform=[lambda x: mask_with_prob_color(x, 0.75, "random_one")]).generate_dataloader(indices)
# masked_quarter_random_one_lfw = UnmaskedFaces("test", transform=[lambda x: mask_lower_half_of_image(x, "random_one", 0.75)]).generate_dataloader(indices)
# masked_half_random_one_lfw = UnmaskedFaces("test", transform=[lambda x: mask_lower_half_of_image(x, "random_one", 0.5)]).generate_dataloader(indices)
# masked_three_quarters_random_one_lfw = UnmaskedFaces("test", transform=[lambda x: mask_lower_half_of_image(x, "random_one", 0.25)]).generate_dataloader(indices)

# masked_25_percent_random_many_lfw = UnmaskedFaces("test", transform=[lambda x: mask_with_prob_color(x, 0.25, "random_many")]).generate_dataloader(indices)
# masked_50_percent_random_many_lfw = UnmaskedFaces("test", transform=[lambda x: mask_with_prob_color(x, 0.50, "random_many")]).generate_dataloader(indices)
# masked_75_percent_random_many_lfw = UnmaskedFaces("test", transform=[lambda x: mask_with_prob_color(x, 0.75, "random_many")]).generate_dataloader(indices)
# masked_quarter_random_many_lfw = UnmaskedFaces("test", transform=[lambda x: mask_lower_half_of_image(x, "random_many", 0.75)]).generate_dataloader(indices)
# masked_half_random_many_lfw = UnmaskedFaces("test", transform=[lambda x: mask_lower_half_of_image(x, "random_many", 0.5)]).generate_dataloader(indices)
# masked_three_quarters_random_many_lfw = UnmaskedFaces("test", transform=[lambda x: mask_lower_half_of_image(x, "random_many", 0.25)]).generate_dataloader(indices)

#truly_masked = UnmaskedFaces("train", transform=[lambda x: x]).generate_dataloader_2(
#    jax.random.randint(key, (1, 100), 10, 100)[0]
#)

In [None]:
# Visualize datasets

# Loader factories to preview; one image per entry is drawn below.
# Entries that are commented out correspond to datasets whose definitions are
# currently disabled in the "Initialize datasets" cell.
dataloaders = [
    unmasked_train_lfw,
    #unmasked_val_lfw, 
    #unmasked_test_lfw, 
    # masked_25_percent_black_lfw, 
    # masked_50_percent_black_lfw, 
    # masked_75_percent_black_lfw, 
    # masked_quarter_black_lfw, 
    # masked_half_black_lfw,
    # masked_three_quarters_black_lfw, 
    # masked_25_percent_random_one_lfw, 
    # masked_50_percent_random_one_lfw, 
    # masked_75_percent_random_one_lfw, 
    # masked_quarter_random_one_lfw, 
    # masked_half_random_one_lfw,
    # masked_three_quarters_random_one_lfw,
    # masked_25_percent_random_many_lfw, 
    # masked_50_percent_random_many_lfw, 
    # masked_75_percent_random_many_lfw, 
    # masked_quarter_random_many_lfw, 
    # masked_half_random_many_lfw, 
    # masked_three_quarters_random_many_lfw,
]

# Shared figure that plot_img() adds its subplots to.
fig = plt.figure(figsize=(100,100))
def get_first_img(dataloader):
    """Return the first image yielded by ``dataloader(1)``.

    ``dataloader`` is a factory taking a batch size. The shape of the first image
    batch is printed and its first element returned; ``None`` is returned when the
    loader yields nothing.
    """
    first_batch = next(iter(dataloader(1)), None)
    if first_batch is None:
        return None
    images, _ = first_batch
    print(images.shape)
    return images[0]

def plot_img(img, num):
    """Draw ``img`` without axis decorations in slot ``num`` of the 5x5 grid on ``fig``.

    Args:
        img: image in a layout accepted by ``imshow`` (the caller passes height x
            width x channels).
        num: zero-based position in the grid; subplot numbers are one-based.
    """
    # Draw on the axes we just created rather than on pyplot's implicit "current" axes.
    axes = fig.add_subplot(5, 5, num + 1)
    axes.imshow(img)
    axes.axis("off")

# Show the first image of every loader; tensors are channel-first, imshow wants channel-last.
for i, loader_factory in enumerate(dataloaders):
    img = torch.permute(get_first_img(loader_factory), (1, 2, 0))
    plot_img(img, i)
plt.show()

In [None]:
# Define architecture of encoder

class Encoder(hk.nets.ResNet):
    """Haiku ResNet (configuration ``CONFIGS[50]``) with an optional projection head.

    Args:
        projection: size of the final linear layer. When falsy, the network is built
            with a single-unit head that is never applied and ``__call__`` returns the
            globally average-pooled features instead.
    """

    def __init__(self, projection):
        self.projection = projection
        resnet_config = hk.nets.ResNet.CONFIGS[50].copy()
        head_size = projection if projection else 1
        super().__init__(num_classes=head_size, **resnet_config)

    def __call__(self, inputs, is_training, memory=None, test_local_stats=False):
        """Run the ResNet forward pass.

        ``memory`` is accepted but unused. Returns the projection-head output when
        ``projection`` is set, otherwise the features averaged over axes 1 and 2.
        """
        features = self.initial_conv(inputs)
        if not self.resnet_v2:
            features = self.initial_batchnorm(features, is_training, test_local_stats)
            features = jax.nn.relu(features)

        features = hk.max_pool(
            features,
            window_shape=(1, 3, 3, 1),
            strides=(1, 2, 2, 1),
            padding="SAME",
        )

        for group in self.block_groups:
            features = group(features, is_training, test_local_stats)

        if self.resnet_v2:
            features = self.final_batchnorm(features, is_training, test_local_stats)
            features = jax.nn.relu(features)

        pooled = jnp.mean(features, axis=(1, 2))
        return self.logits(pooled) if self.projection else pooled

In [None]:
# Define (and visualize) architecture of Hopfield network 

class Hopfield(hk.Module):
    """Encoder followed by one associative-retrieval step and a linear classifier.

    Each encoded query is compared with every row of a memory matrix using ``sim``,
    the scores are scaled by ``beta`` and turned into weights by ``sep``, and the
    weighted sum of the memory rows is classified by a linear layer.

    Args:
        beta: scale applied to the similarity scores.
        sim: ``sim(K, q)`` with ``K`` of shape ``(N, d)`` and ``q`` of shape ``(d, 1)``;
            must return scores of shape ``(N, 1)``.
        sep: separation function called as ``sep(scores, axis=0)`` (e.g. softmax).
        norm: if true, divide the similarity and separation scores by their sums.
        self_retrieval: if true, queries retrieve from the memory passed to
            ``__call__`` (or from the batch itself); otherwise the batch is split in
            two halves that retrieve from each other.
        num_classes: output size of the final linear layer.
    """

    def __init__(self, beta, sim, sep, norm=False, self_retrieval=True, num_classes=5):
        super().__init__()
        self.beta = beta
        self.sim = sim
        self.sep = sep
        self.norm = norm
        self.self_retrieval = self_retrieval
        self.encoder = Encoder(projection=None)
        self.projection = hk.Linear(output_size=num_classes)

    def hopfield(self, memory, query):
        """Retrieve from ``memory`` (N, d) for every row of ``query`` (B, d).

        Returns:
            ``(retrieved, weights)`` with shapes ``(B, d)`` and ``(B, N)``.
        """
        # memory is shared by all queries; the query batch is mapped over axis 0.
        @functools.partial(jax.vmap, in_axes=(None, 0))
        def sim_sep_project(memory, query):
            query = jax.lax.expand_dims(query, [-1])
            sim_score = self.beta * self.sim(memory, query)
            if self.norm:
                sim_score = sim_score / jnp.sum(sim_score)
            sep_score = self.sep(sim_score, axis=0)
            if self.norm:
                sep_score = sep_score / jnp.sum(sep_score)
            out = jnp.dot(memory.T, sep_score)
            return jax.lax.squeeze(out, [1]), jax.lax.squeeze(sep_score, [1])
        return sim_sep_project(memory, query)

    def apply_self_retrieval(self, memory, query):
        """Retrieve for ``query`` from ``memory``; returns ``(retrieved, weights)``."""
        return self.hopfield(memory, query)

    def apply_no_self_retrieval(self, x):
        """Split ``x`` into two halves that retrieve from each other.

        The batch size must be even (``jnp.split`` into two equal parts). Returns
        ``(retrieved, None)``; no attention weights are reported in this mode.
        """
        one, two = jnp.split(x, 2)
        x_one, _ = self.hopfield(two, one)
        x_two, _ = self.hopfield(one, two)
        x = jnp.concatenate((x_one, x_two), axis=0)
        return x, None

    @staticmethod
    def eucdliean_distance(K, q):
        """Negative squared Euclidean distance between ``q`` and each row of ``K``."""
        # q is a (d, 1) column; transpose it so it broadcasts against the rows of K,
        # and keep the trailing axis so the result is (N, 1) like dot_product.
        return -jnp.sum(jnp.square(K - q.T), axis=1, keepdims=True)

    # Correctly spelled alias; the original name is kept for existing configurations.
    euclidean_distance = eucdliean_distance

    @staticmethod
    def manhattan_distance(K, q):
        """Negative L1 distance between ``q`` and each row of ``K``; shape ``(N, 1)``."""
        return -jnp.sum(jnp.abs(K - q.T), axis=1, keepdims=True)

    @staticmethod
    def cosine_similarity(K, q):
        """Cosine similarity between ``q`` and each row of ``K``; shape ``(N, 1)``."""
        row_norms = jnp.linalg.norm(K, axis=1, keepdims=True)
        # The lower bound avoids dividing by zero for all-zero vectors.
        return (K @ q) / jnp.maximum(row_norms * jnp.linalg.norm(q), 1e-8)

    @staticmethod
    def dot_product(K, q):
        """Dot product between ``q`` and each row of ``K``; shape ``(N, 1)``."""
        return K @ q

    def __call__(self, x, memory, is_training):
        """Encode ``x``, retrieve, and classify.

        Args:
            x: input batch for the encoder.
            memory: optional batch of raw inputs to encode and use as memory; ignored
                when ``self_retrieval`` is false. With ``None`` the encoded batch is
                its own memory.
            is_training: forwarded to the encoder.

        Returns:
            ``(logits, weights)``; ``weights`` is ``None`` when ``self_retrieval`` is
            false.
        """
        x = self.encoder(x, is_training)
        if not self.self_retrieval:
            retrieved, weights = self.apply_no_self_retrieval(x)
        elif memory is not None:
            retrieved, weights = self.apply_self_retrieval(self.encoder(memory, is_training), x)
        else:
            retrieved, weights = self.apply_self_retrieval(x, x)
        return self.projection(retrieved), weights

In [None]:
# Define utility functions

def get_acc(logits, labels):
    """Return the fraction of rows whose arg-max prediction equals the one-hot label.

    Args:
        logits: ``(batch, classes)`` scores.
        labels: ``(batch, classes)`` one-hot targets.
    """
    n_classes = labels.shape[1]
    predicted_class = jnp.argmax(logits, axis=1)
    encoded = jax.nn.one_hot(predicted_class, num_classes=n_classes)
    # A row counts as correct only if every entry of its one-hot vector matches.
    row_is_correct = jnp.all(labels == encoded, axis=1)
    return jnp.mean(row_is_correct)

In [None]:
def _resolve_similarity(sim):
    """Return the similarity callable named by ``sim``.

    ``sim`` may already be a callable (returned unchanged) or a string such as
    ``"Hopfield.dot_product"``, the form used by the sweep configuration. Names are
    looked up on ``Hopfield`` (then in this module) instead of being passed to
    ``eval``, so a configuration value can never execute arbitrary code.
    """
    if callable(sim):
        return sim
    name = str(sim).rsplit(".", 1)[-1]
    resolved = getattr(Hopfield, name, None) or globals().get(name)
    if not callable(resolved):
        raise ValueError(f"unknown similarity function: {sim!r}")
    return resolved


def train_val_test(
    train_loader,
    val_loader,
    test_loaders,
    batch_size,
    epochs,
    sim,
    method_num,
    task_num,
    beta,
    encoder_only,
    pretrained_encoder,
    save_as,
    sweeps):
    """Train, validate and test a classifier, logging everything to Weights & Biases.

    Args:
        train_loader: factory ``batch_size -> iterable`` of ``(images, one_hot_labels)``
            batches with channel-first images.
        val_loader: same kind of factory for validation data.
        test_loaders: dict mapping a name (used in the logged metric names) to a
            loader factory.
        batch_size: batch size handed to every loader factory.
        epochs: number of training epochs.
        sim: similarity function, either a callable or its name
            (e.g. ``"Hopfield.dot_product"``).
        method_num: ``1`` trains with self-retrieval; any other value trains with the
            batch split into two halves that retrieve from each other. Testing always
            uses self-retrieval.
        task_num: only task ``1`` is implemented.
        beta: similarity scale passed to ``Hopfield``.
        encoder_only: train ``Encoder(5)`` alone instead of the Hopfield model; the
            memory experiments and attention maps are skipped in that case.
        pretrained_encoder: not implemented.
        save_as: file path to save ``(params, state)`` to after training, or falsy.
        sweeps: true when a wandb run has already been started by the caller, so no
            new run is initialised here.

    Raises:
        NotImplementedError: for ``task_num != 1`` or ``pretrained_encoder`` without
            ``encoder_only``.
        ValueError: if ``sim`` names an unknown function.
    """
    if task_num != 1:
        raise NotImplementedError(f"task_num={task_num!r} is not implemented; only task 1 is")
    if pretrained_encoder and not encoder_only:
        raise NotImplementedError("pretrained_encoder is not implemented")

    sim = _resolve_similarity(sim)
    if not sweeps:
        wandb.init(project="final_hopfield_masked")

    def hopfield_model(self_retrieval):
        def net(x, memory, is_training):
            module = Hopfield(beta, sim=sim, sep=jax.nn.softmax, self_retrieval=self_retrieval)
            return module(x, memory, is_training)
        return net

    if encoder_only:
        # Same call signature and (logits, attention) return shape as the Hopfield model.
        def model(x, memory, is_training):
            return Encoder(5)(x, is_training), None
        test_model = model
    else:
        model = hopfield_model(method_num == 1)
        test_model = hopfield_model(True)

    # Record the hyper-parameters on the active run (rebinding wandb.config would
    # only replace the module attribute and log nothing).
    wandb.config.update({
        "batch_size" : batch_size,
        "epochs" : epochs,
        "beta" : beta,
        "encoder_only" : encoder_only,
        "task_num" : task_num,
        "method_num" : method_num,
        "save_as" : save_as,
        "sweeps" : sweeps
    }, allow_val_change=True)

    train_loader_fn = train_loader
    train_loader = train_loader(batch_size)
    val_loader = val_loader(batch_size)
    test_loaders = {k : v(batch_size) for k, v in test_loaders.items()}

    forward = hk.without_apply_rng(hk.transform_with_state(model))
    forwardl = hk.without_apply_rng(hk.transform_with_state(test_model))
    sample_input = jnp.ones([batch_size, 100, 100, 3])
    sample_labels = jnp.ones([batch_size, 5])
    rng_key = jax.random.PRNGKey(42)
    params, state = forward.init(rng_key, sample_input, None, True)

    optimizer = optax.adam(1e-4)
    opt_state = optimizer.init(params)

    def prep_data(data, labels):
        """Convert a torch batch to jax arrays with channel-last images."""
        data, labels = jnp.asarray(data), jnp.asarray(labels)
        data = jnp.transpose(data, (0, 2, 3, 1))
        return data, labels

    @jax.jit
    def task1_eval(params, state, data, labels, memory):
        logits, state = forward.apply(params, state, data, memory, is_training=False)
        loss = optax.softmax_cross_entropy(logits[0], labels)
        return loss.mean(), (logits, state)

    def task1_test(params, state, data, labels, memory):
        logits, state = forwardl.apply(params, state, data, memory, is_training=False)
        loss = optax.softmax_cross_entropy(logits[0], labels)
        return loss.mean(), (logits, state)

    def task1_loss(params, state, data, labels, memory):
        logits, state = forward.apply(params, state, data, memory, is_training=True)
        loss = optax.softmax_cross_entropy(logits[0], labels)
        return loss.mean(), (logits, state)

    @jax.jit
    def task1_update(params, state, opt_state, data, labels, memory):
        grad_fn = jax.value_and_grad(task1_loss, has_aux=True)
        ((loss, (logits, state)), grads) = grad_fn(params, state, data, labels, memory)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, state, opt_state, loss, logits

    # compile jitted functions
    task1_eval(params, state, sample_input, sample_labels, None)
    task1_update(params, state, opt_state, sample_input, sample_labels, None)

    def train_and_val(params, state, opt_state):
        """Run all epochs, logging mean loss/accuracy for training and validation."""
        for epoch in range(epochs):
            t_loss, t_acc = [], []
            v_loss, v_acc = [], []

            for data, labels in train_loader:
                data, labels = prep_data(data, labels)
                res = task1_update(params, state, opt_state, data, labels, None)
                (params, state, opt_state, loss, (logits, _)) = res
                t_loss.append(loss)
                t_acc.append(get_acc(logits, labels))
            wandb.log({"train_loss" : float(jnp.array(t_loss).mean()), "epoch" : epoch})
            wandb.log({"train_accuracy" : float(jnp.array(t_acc).mean()), "epoch" : epoch})

            for data, labels in val_loader:
                data, labels = prep_data(data, labels)
                loss, ((logits, _), _) = task1_eval(params, state, data, labels, None)
                v_loss.append(loss)
                v_acc.append(get_acc(logits, labels))
            wandb.log({"val_loss" : float(jnp.array(v_loss).mean()), "epoch" : epoch})
            wandb.log({"val_accuracy" : float(jnp.array(v_acc).mean()), "epoch" : epoch})
            print(f"Finished epoch {epoch}")

        if save_as:
            with open(save_as, 'wb') as f:
                jnp.save(f, (params, state), allow_pickle=True)

        return params, state

    def test(params, state):
        """Evaluate every test loader without memory and with growing training memories."""
        mem_count = [50*i for i in range(1,9)]
        # One training batch of each size serves as external memory (Hopfield model only).
        memories = [] if encoder_only else [
            prep_data(*next(iter(train_loader_fn(i)))) for i in mem_count
        ]
        res = {}

        for k, v in test_loaders.items():
            acc_without_mem, acc_with_mem = [], []
            maps_without_mem, maps_with_mem = None, None
            n_labels = None
            for data, labels in v:
                data, labels = prep_data(data, labels)

                _, ((logits, attn_map), _) = task1_test(params, state, data, labels, None)
                acc_without_mem.append(get_acc(logits, labels))
                # Keep the attention map (and its labels) of the first batch only.
                if maps_without_mem is None and attn_map is not None:
                    maps_without_mem, n_labels = attn_map, labels

                if not encoder_only:
                    accs, all_maps = [], []
                    for t_data, t_labels in memories:
                        _, ((logits, attn_map), _) = task1_test(params, state, data, labels, t_data)
                        accs.append(get_acc(logits, labels))
                        all_maps.append((attn_map, t_labels, labels))
                    acc_with_mem.append(accs)
                    if maps_with_mem is None:
                        maps_with_mem = all_maps

            if not acc_without_mem:
                # The loader produced no batches; there is nothing to log for it.
                continue

            wandb.log({f"{k}_test_accuracy_no_memory" : float(jnp.array(acc_without_mem).mean())})

            if maps_without_mem is not None:
                unencoded_labels = jnp.argmax(n_labels, axis=1).tolist()
                unencoded_labels = [f"{unencoded_labels[i]}({i})" for i in range(len(unencoded_labels))]
                plot_without_mem = pd.DataFrame(
                    np.asarray(maps_without_mem), columns=unencoded_labels, index=unencoded_labels
                )
                plots_without_mem = px.imshow(
                    plot_without_mem, labels=dict(x="Inputs", y="Inputs", color="Similarity")
                )
                wandb.log({f"{k}_attn_maps_no_memory" : plots_without_mem})

            if not encoder_only:
                # Mean accuracy over batches, one value per memory size.
                res[k] = [jnp.array(acc_with_mem).mean(axis=0).tolist(), maps_with_mem]

        for mem in range(len(memories)):
            log_entry = {"memories" : mem_count[mem]}
            for k, (accs, maps) in res.items():
                log_entry[f"{k}_test_accuracy_with_memory"] = accs[mem]

                attn_map, t_labels, labels = maps[mem]
                columns = jnp.argmax(t_labels, axis=1).tolist()
                rows = jnp.argmax(labels, axis=1).tolist()
                frame = pd.DataFrame(
                    attn_map.tolist(),
                    columns=[f"{columns[i]}({i})" for i in range(len(columns))]
                )
                frame.index = [f"{rows[i]}({i})" for i in range(len(rows))]
                # The title is a figure argument, not an axis label.
                log_entry[f"{k}_attn_maps_with_memory"] = px.imshow(
                    frame,
                    labels=dict(x="Memory", y="Inputs", color="Similarity"),
                    title=f"N = {mem_count[mem]}"
                )

            wandb.log(log_entry)

    p, s = train_and_val(params, state, opt_state)
    test(p, s)

In [None]:
# `!export ...` only sets the variable inside a short-lived subshell. Set it in the
# kernel's own environment instead; JAX has to see it before it first initialises
# its GPU backend, so this must run before any JAX computation.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.7"

train_val_test(
    train_loader=unmasked_train_lfw,
    val_loader=unmasked_val_lfw,
    test_loaders={
        "truly_masked" : truly_masked, 
        "unmasked_test_lfw" : unmasked_test_lfw
    },
    batch_size=32,
    epochs=1,
    # Passed by name, like the sweep configuration does.
    sim="Hopfield.dot_product",
    method_num=2,
    task_num=1,
    beta=0.001,
    encoder_only=False,
    pretrained_encoder=False,
    save_as=False,
    sweeps=False)
    

In [None]:
def beta_sweep():
    """Run one sweep trial: train and test with the hyper-parameters chosen by wandb.

    Reads ``batch``, ``sim``, ``method`` and ``beta`` from the run configuration.
    Training happens inside the ``wandb.init()`` context so that all logging goes to
    the run that belongs to this trial.
    """
    # A shell `export` would not reach this process; set the variable directly.
    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.7"

    with wandb.init() as _:
        config = wandb.config

        train_val_test(
            train_loader=unmasked_train_lfw,
            val_loader=unmasked_val_lfw,
            test_loaders={
                "truly_masked" : truly_masked, 
                "unmasked_test_lfw" : unmasked_test_lfw
            },
            batch_size=config["batch"],
            epochs=1,
            sim=config["sim"],
            method_num=config["method"],
            task_num=1,
            beta=config["beta"],
            encoder_only=False,
            pretrained_encoder=False,
            save_as=False,
            sweeps=True)

In [None]:
# Shape check for the retrieval product in Hopfield.hopfield: memory.T is (d, N) and
# the weights are (N, 1), so the inner dimensions must agree (here d=2048, N=32).
jnp.dot(jnp.ones((2048, 32)), jnp.ones((32, 1))).shape

In [None]:
# Scratch check: shape left after squeezing axis 1 of a (100, 1) array.
s=jnp.ones((100,1))
jax.lax.squeeze(s, [1]).shape

In [None]:
# Grid sweep over every combination of beta, method, batch size and similarity
# function. The parameter names are the keys beta_sweep() reads from the run config;
# "sim" values are function names given as strings.
sweep_config = {
  "name" : "(32, beta)",
  "method" : "grid",
  "parameters" : {
    "beta"  : {"values" : [float(i) for i in 
                           [32,16,8,4,2,0.1, 0.01, 0.001, 
                            0.0001]
                          ]}, 
    "method" : {"values": [1,2]},
    "batch" : {"values":[32, 64, 128]}, 
    "sim" : {"values":["Hopfield.eucdliean_distance", "Hopfield.manhattan_distance", "Hopfield.cosine_similarity", "Hopfield.dot_product"]}
  }}

# Register the sweep, then run trials in this process via beta_sweep.
sweep_id = wandb.sweep(sweep_config, project="final_hopfield_masked")
wandb.agent(sweep_id, function=beta_sweep)

In [None]:
# Stand-alone similarity functions for experimenting outside the Hopfield class.
# Convention: K is an (N, d) memory matrix, q a (d, 1) query column, and every
# function returns one score per memory row with shape (N, 1).

def eucdliean_distance(K, q):
    """Negative squared Euclidean distance between ``q`` and each row of ``K``."""
    return -jnp.sum(jnp.square(K - q.T), axis=1, keepdims=True)


def manhattan_distance(K, q):
    """Negative L1 distance between ``q`` and each row of ``K``."""
    return -jnp.sum(jnp.abs(K - q.T), axis=1, keepdims=True)


def cosine_similarity(K, q):
    """Cosine similarity between ``q`` and each row of ``K``."""
    row_norms = jnp.linalg.norm(K, axis=1, keepdims=True)
    # The lower bound avoids dividing by zero for all-zero vectors.
    return (K @ q) / jnp.maximum(row_norms * jnp.linalg.norm(q), 1e-8)


def dot_product(K, q):
    """Dot product between ``q`` and each row of ``K``."""
    return K @ q

In [None]:
import torch

In [None]:
# Check whether PyTorch can see a CUDA device; the CIFAR-10 cell below moves the
# network and every batch to "cuda".
torch.cuda.is_available()

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
# Convert images to tensors and normalise each channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 4

# CIFAR-10 is downloaded to ./data on first use.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

# Human-readable class names (not used elsewhere in this cell).
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Small CNN: two conv + max-pool stages followed by three linear layers.

    The first linear layer expects 16 * 5 * 5 features, which is what the two
    conv/pool stages produce for 3 x 32 x 32 inputs. The output has 10 values per
    sample.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return the 10 class scores for a batch of images."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # Flatten everything except the batch dimension.
        x = torch.flatten(x, 1)
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)


# Use the GPU when one is available and fall back to the CPU otherwise, instead of
# failing on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = Net().to(device)
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

print('Finished Training')

In [None]:
import numpy as np
def t(x):
    """Ignore ``x`` and return a length-1 array holding a random integer from {0, 1, 2}."""
    return np.random.choice(3, 1)

t(1)

In [None]:
# Stack 100 independent draws from t into one array.
np.stack([t(1) for _ in range(100)])

In [None]:
import pytorch_lightning as pl
from data import UnmaskedFaces 
from pytorch_lightning.loggers import WandbLogger
from pl_bolts.models.self_supervised import SimCLR
from pl_bolts.models.self_supervised.simclr.transforms import (
    SimCLRTrainDataTransform,  SimCLREvalDataTransform
)


# Log this run to Weights & Biases.
wandb_logger = WandbLogger(name='SimCLR_32', project='associative-vision-models')


# NOTE(review): the transform classes themselves (not instances) are passed here, and
# neither object is handed to the trainer below - confirm what `data.UnmaskedFaces`
# expects and how these are meant to be turned into dataloaders.
unmasked_train = UnmaskedFaces("train", SimCLRTrainDataTransform)
unmasked_val = UnmaskedFaces("val", SimCLREvalDataTransform)


trainer = pl.Trainer(max_epochs=1, logger=wandb_logger, gpus=1, default_root_dir="simclr")
# NOTE(review): `model`, `train_dataloader` and `val_dataloader` are not defined at
# notebook level in any visible cell; this call raises NameError until they are
# created (presumably a SimCLR model and loaders built from the objects above).
trainer.fit(model, train_dataloader, val_dataloader)

In [None]:
import torchvision.models

In [None]:
# torchvision has no `resnet50_bn` factory (the `_bn` suffix exists only for the VGG
# variants); the standard ResNet-50 already contains batch-norm layers.
torchvision.models.resnet50()

In [None]:
import torch
import torch.nn as nn
from torchvision.models import resnet50
import torch.nn.functional as F
# torchvision ResNet-50 built with default arguments (no weights are requested here).
encoder = resnet50()
# Replace the final fully connected layer with a pass-through so the model returns
# the features that would otherwise be fed to the classifier.
encoder.fc = nn.Identity()

In [None]:
class test(nn.Module):
    """Toy module demonstrating how to split parameters into weight-decay groups.

    After construction ``param_groups`` holds two optimizer parameter groups: one
    with weight decay and one (biases / batch-norm parameters) without.
    """

    def __init__(self):
        super().__init__()

        self.layer1 = nn.Linear(10, 1)
        print(list(self.parameters()))
        # Store the groups under their own name: assigning them to `self.parameters`
        # would shadow nn.Module.parameters() and break every later call to it.
        self.param_groups = self.exclude_from_wt_decay(self.named_parameters(), 1)
        print(self.param_groups)

    def exclude_from_wt_decay(self, named_params, weight_decay, skip_list=('bias', 'bn')):
        """Split parameters into a decayed and a non-decayed optimizer group.

        Args:
            named_params: iterable of ``(name, parameter)`` pairs.
            weight_decay: weight decay for the regular group.
            skip_list: parameters whose name contains any of these substrings get a
                weight decay of 0.

        Returns:
            A list of two dicts in the format accepted by torch optimizers. Parameters
            with ``requires_grad == False`` are left out of both groups.
        """
        params = []
        excluded_params = []

        for name, param in named_params:
            if not param.requires_grad:
                continue
            if any(layer_name in name for layer_name in skip_list):
                excluded_params.append(param)
            else:
                params.append(param)

        return [
            {'params': params, 'weight_decay': weight_decay},
            {'params': excluded_params, 'weight_decay': 0.}
        ]

In [6]:
import os
import glob
import torch
import numpy as np
from PIL import Image
import torch.nn.functional as F
from torchvision import transforms as T
from pydash import flatten_deep as flatten
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler

class GenericDataset(Dataset):
    """Map-style dataset that applies a per-split transform to its items up front.

    Args:
        data: sequence of raw items (the caller's sequence is copied, not modified).
        labels: sequence of labels, indexed like ``data``.
        transforms: sequence of callables ``[train, val, test]``. When fewer than
            three are given, the missing splits use the last one.
        train_split, val_split, test_split: indices into ``data`` for each split.
    """

    def __init__(self,
                 data,
                 labels,
                 transforms,
                 train_split,
                 val_split,
                 test_split):

        # Work on a copy so that transforming items does not alter the caller's list.
        self.data = list(data)
        self.labels = labels
        self.transforms = transforms

        self.train_split = train_split
        self.val_split = val_split
        self.test_split = test_split

        self.setup()

    def _transform_for(self, position):
        """Return ``transforms[position]``, or the last transform if it does not exist."""
        if len(self.transforms) > position:
            return self.transforms[position]
        return self.transforms[-1]

    def setup(self):
        """Replace every item by its transformed version, using its split's transform."""
        splits = (self.train_split, self.val_split, self.test_split)
        for position, split in enumerate(splits):
            for i in split:
                self.data[i] = self._transform_for(position)(self.data[i])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


class UnmaskedFaces():
    """Loads the LFW face images from ``LFW/<person>/*.jpg`` and serves all three splits.

    Args:
        transforms: sequence of callables ``[train, val, test]`` applied to the image
            tensor of the corresponding split (see ``GenericDataset`` for how fewer
            than three entries are handled).
        splits: ``[train_fraction, val_fraction]``; the remainder is the test split.
    """

    def __init__(self, transforms, splits=[0.8, 0.1]):
        super().__init__()
        np.random.seed(42)

        self.splits = splits
        # Open the file, convert it to a tensor, then apply the caller's transform.
        self.transforms = [T.Compose([Image.open, T.ToTensor(), t]) for t in transforms]
        # Person directory name -> integer class id, in os.listdir order.
        self.label_mapping = {name: idx for idx, name in enumerate(os.listdir("LFW"))}

    @staticmethod
    def mask_lower_half_of_image(img, value, amount=0.5):
        """Mask a band at the bottom of a ``(channels, height, width)`` image tensor.

        Args:
            img: image tensor; any channel count and non-square sizes are supported.
            value: ``0`` fills the band with one random colour (one value per channel),
                ``1`` with independent random values per pixel, ``2`` with zeros.
            amount: controls the band height (see note below).

        Returns:
            The masked image. Modes 0 and 1 write into ``img`` in place and return it;
            mode 2 returns a new tensor.

        Raises:
            ValueError: if ``value`` is not 0, 1 or 2.

        Note:
            Modes 0 and 1 mask the bottom ``int(height * (1 - amount))`` rows, mode 2
            masks everything below row ``int(height * (1 - amount))``. For
            ``amount=0.5`` on an even height both are the lower half.
        """
        channels, height, width = img.shape
        half = int(height * (1 - amount))
        minus_half = height - half

        if value == 0:
            colour = torch.rand(channels, 1).unsqueeze(1)
            img[:, minus_half:, :] = colour.expand(channels, half, width)
            return img
        if value == 1:
            img[:, minus_half:, :] = torch.rand((channels, half, width))
            return img
        if value == 2:
            blacked_out = img.clone()
            blacked_out[:, half:, :] = 0
            return blacked_out
        raise ValueError(f"value must be 0, 1 or 2, got {value!r}")

    @staticmethod
    def random_masking_transform():
        """Return a transform that masks the lower half using mode 0 or 1 at random."""
        def mask_fn(x):
            mode = int(np.random.choice(2))
            return UnmaskedFaces.mask_lower_half_of_image(x, mode)
        return T.Lambda(mask_fn)

    def get_name(self, path):
        """Return the person name encoded in an image path.

        The file name is split on ``_`` and every piece containing a digit is dropped,
        e.g. ``LFW/John_Doe/John_Doe_0001.jpg`` -> ``John_Doe``.
        """
        pieces = os.path.basename(path).split("_")
        return "_".join(p for p in pieces if not any(ch.isdigit() for ch in p))

    def read_LFW(self):
        """Return one list of ``.jpg`` paths per person directory under ``LFW``."""
        return [glob.glob(f"LFW/{person}/*.jpg") for person in os.listdir("LFW")]

    def preprocess(self, sample_list):
        """Return ``(paths, one_hot_labels)``; images are loaded later by the dataset."""
        names = [self.get_name(path) for path in sample_list]
        labels = torch.tensor([self.label_mapping[name] for name in names])
        return sample_list, F.one_hot(labels, len(os.listdir("LFW")))

    def _split_indices(self, n_samples):
        """Return shuffled, disjoint ``(train, val, test)`` index arrays.

        A private, fixed-seed generator keeps the split reproducible regardless of
        other uses of ``np.random``.
        """
        # permutation() returns a shuffled copy; the result must be kept.
        indices = np.random.RandomState(42).permutation(n_samples)
        n_train = int(n_samples * self.splits[0])
        n_val = int(n_samples * self.splits[1])
        return (
            indices[:n_train],
            indices[n_train:n_train + n_val],
            indices[n_train + n_val:],
        )

    def generate_dataloader(self):
        """Build the dataset and return a ``(batch_size, mode) -> DataLoader`` factory.

        ``mode`` is ``"train"``, ``"val"`` or anything else for the test split; the
        loader samples randomly from that split's indices.
        """
        data = flatten(self.read_LFW())
        train_split, val_split, test_split = self._split_indices(len(data))

        dataset = GenericDataset(*self.preprocess(data),
                                 self.transforms,
                                 train_split,
                                 val_split,
                                 test_split)

        return lambda b, m: DataLoader(
            dataset,
            batch_size=b,
            sampler=SubsetRandomSampler(
                train_split if m == "train" else
                val_split if m == "val" else
                test_split)
        )

In [8]:
# Build the (batch_size, mode) -> DataLoader factory, passing a single identity
# function as the transform list.
s = UnmaskedFaces([lambda x: x]).generate_dataloader()

In [3]:
# `!ulimit -n` only changes the limit of a short-lived subshell. Raise the open-file
# limit of the kernel process itself instead (Unix only), never above the hard limit.
import resource

_soft, _hard = resource.getrlimit(resource.RLIMIT_NOFILE)
_wanted = 10000 if _hard == resource.RLIM_INFINITY else min(10000, _hard)
if _soft != resource.RLIM_INFINITY and _soft < _wanted:
    try:
        resource.setrlimit(resource.RLIMIT_NOFILE, (_wanted, _hard))
    except (ValueError, OSError) as err:
        print(f"could not raise the open-file limit to {_wanted}: {err}")

In [9]:
# Create a loader over the training split with batch size 64 (the cell output is
# just the DataLoader object; nothing is iterated here).
s(64, "train")

<torch.utils.data.dataloader.DataLoader at 0x7faaaf6a1610>