In [None]:
class PosEmb(nn.Module):
  def __init__(self):
    super(PosEmb, self).__init__()
    self.conv = nn.Conv1d(768, 768, kernel_size=(128,), stride=(1,), padding=(64,), groups=16)
    self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2)
    self.activation = nn.GELU()
  def forward(self, x):
    x = self.conv(x)
    x = x[:, :, :-1]
    return torch.permute(self.activation(x), (0,2,1))

In [None]:
import torch.nn as nn
import torch

class FeatureEncoder(nn.Module):
  def __init__(self, K, S, class_dim):
    super(FeatureEncoder, self).__init__()
    self.conv0 = []
    self.K = K
    self.S = S
    self.conv0.append(nn.Sequential(
        nn.Conv1d(128, 512, kernel_size=(self.K[0],), stride=(self.S[0],), bias=False),
        nn.GELU(),
        nn.GroupNorm(512, 512, eps=1e-05, affine=True)
    ))
    self.conv1 = []
    for i in range(4):
      self.conv1.append(nn.Sequential(
          nn.Conv1d(512, 512, kernel_size=(self.K[1],), stride=(self.S[1],), bias=False),
          nn.GELU()
      ))
    self.conv2 = []
    for i in range(2):
      self.conv2.append(nn.Sequential(
          nn.Conv1d(512, 512, kernel_size=(self.K[5],), stride=(self.S[5],), bias=False),
          nn.GELU()
        ))
    self.extractor = nn.Sequential(*self.conv0, *self.conv1, *self.conv2)
    self.projection = nn.Sequential(
        nn.LayerNorm((512,), eps=1e-05, elementwise_affine=True),
        nn.Linear(in_features=512, out_features=768, bias=True),
        nn.Dropout(p=0.1, inplace=False)
    )
    self.pos_emb = PosEmb()
    self.norm = nn.Sequential(
        nn.LayerNorm((768,), eps=1e-05, elementwise_affine=True),
        nn.Dropout(p=0.1, inplace=False)
    )
    self.attention = []
    for i in range(12):
      self.attention.append(nn.MultiheadAttention(768, 1))
    self.feed_forward = nn.Sequential(
        nn.Dropout(p=0.1, inplace=False),
        nn.Linear(in_features=768, out_features=3072, bias=True),
        nn.GELU(),
        nn.Linear(in_features=3072, out_features=768, bias=True),
        nn.Dropout(p=0.1, inplace=False),
        nn.LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
    self.classify = nn.Sequential(
        nn.Conv2d(class_dim, 16, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
        nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2, padding=0),
        nn.Flatten(),
        nn.Linear(384, 128),  # Adjusted input size for the fully connected layer
        nn.ReLU(),
        nn.Linear(128, 1),
        nn.Sigmoid()
        )

  def receptive_field(self):
    St = 1
    stride = self.S[::-1]
    for s in stride:
      St *= s
    R = 1
    i = 0
    kernel = self.K[::-1]
    for k in kernel:
      R = R*stride[i] + (k - stride[i])
      i += 1
    return R, St

  def feature_encoder(self, x):
    x = self.extractor(x)
    x = torch.permute(x, (0,2,1))
    return torch.permute(self.projection(x), (0,2,1))

  def context_encoder(self, x):
    x = self.pos_emb(x) + torch.permute(x, (0,2,1))
    x = self.norm(x)
    for i in range(12):
      x = self.attention[i](x, x, x)[0]
    return torch.permute(self.feed_forward(x), (0,2,1))

  def forward(self, x):
    x = self.feature_encoder(x)
    x = torch.flatten(x, start_dim = 2)
    x = self.context_encoder(x)
    return x


In [None]:
import torch
from torch.nn import functional as F
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
from joblib import Parallel, delayed
import os
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

class MaskedContrastiveLearningTask():
    def __init__(self,
                dataset: torch.utils.data.Dataset,
                task_params={
                    'mask_prob': 0.5
                },
                train_params={
                    'num_epochs': 100,
                    'batch_size': 10,
                    'print_every': 10
                },
                verbose=False
        ):
        self.dataset = dataset
        self.train_test_split()

        self.train_params = train_params
        self.mask_probability = task_params['mask_prob']
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.verbose=verbose

    def train_test_split(self):
        generator = torch.Generator().manual_seed(42)
        self.dataset_train, self.dataset_val = torch.utils.data.random_split(self.dataset, [0.7,0.3], generator=generator)

    def forward(self, model, x):
        '''
        Forward pass of the model
        @parameter
            model:  nn.Module   model
            x    :  tensor      (N x C x T) batched raw input
        @return
            prediction:         (N x D x K) Batch-size embeddings of the model's guess for masked inputs
            masked_latent:      (N x D x K) Batch-size embeddings of the feature encoder output of true masked inputs
            foil_latents:       (N x D x K) Batch-size embeddings of the feature conder output of the foil inputs
        '''
        print(x.shape)
        embeddings = model.feature_encoder(x) # N x D x K
                                              # forward pass of feature encoder generate intermediary embeddings
        if self.verbose:
            print('feature encoder output shape', embeddings.shape)

        # learned masked vector embedding
        masked_vector_learned_embedding = torch.ones((embeddings.shape[0], embeddings.shape[1])) # N x D # TODO
        if self.verbose:
            print('learned masked embeddings shape', masked_vector_learned_embedding.shape)

        # select from the sampled segment L masked inputs
        masked_indices = np.random.choice(embeddings.shape[-1], size=(int(self.mask_probability*embeddings.shape[-1]),), replace=False)
        if self.verbose:
            print('masked indices shape', masked_indices.shape)
        # replace the selected indices with the masked vector embedding
        true_masked_embeddings = embeddings[:,:,masked_indices] # N x D x K # .detach().clone()
        if self.verbose:
            print('true masked embeddings shape', true_masked_embeddings.shape)

        learned_embeddings_replace = embeddings.clone() # if not clone backward pass will complain as inplace modification not allowed
        for i in range(len(masked_indices)):
            learned_embeddings_replace[:,:,i] = masked_vector_learned_embedding
        if self.verbose:
            print('masked embeddings shape', embeddings.shape)

        # feed masked samples to context encoder. Every timestep has an output
        context_encoder_outputs = model.context_encoder(learned_embeddings_replace) # N x D x K
        if self.verbose:
            print('context encoder outputs shape', context_encoder_outputs.shape)

        # context encoder_outputs of the masked input
        predicted_masked_latent = context_encoder_outputs[:,:,masked_indices] # N x D x K
        if self.verbose:
            print('predicted context encoder outputs shape', predicted_masked_latent.shape)
        return predicted_masked_latent, true_masked_embeddings

    def loss(self, predictions, masked_latents):
        '''
        Follow implementation in https://github.com/dhruvbird/ml-notebooks/blob/main/nt-xent-loss/NT-Xent%20Loss.ipynb
        @parameter
            predictions:         (N x D x K) Batch-size embeddings of the model's guess for masked inputs
            masked_latents:      (N x D x K) Batch-size embeddings of the feature encoder output of masked inputs

        @return
            batched mean contrastive loss
        '''
        losses = torch.zeros((masked_latents.shape[-1],), device=self.device)
        # contrastive learning is computed one masked sample at a time
        for k in range(masked_latents.shape[-1]):
            predicted_masked_latent = predictions[:,:,k] # N x D
            if self.verbose:
                print('predicted masked latent shape', predicted_masked_latent.shape)
            cos_sim = F.cosine_similarity(torch.unsqueeze(predicted_masked_latent, dim=-1), masked_latents, dim=1) # N x K
            if self.verbose:
                print('cosine similarity shape', cos_sim.shape)
            labels = torch.zeros([cos_sim.shape[0], cos_sim.shape[1]], device=self.device) # N x K
            labels[:,k] = 1
            # print('labels', labels)
            # losses.append(F.cross_entropy(cos_sim, labels, reduction='mean'))
            losses[k] = F.cross_entropy(cos_sim, labels, reduction='mean')
        if self.verbose:
            print('losses', losses)
        # return torch.mean(torch.tensor(losses))
        return torch.mean(losses)

    def train(self, model, train_params={}):
        print('Training on ', self.device)
        self.train_params.update(train_params)
        num_epochs = self.train_params['num_epochs']
        batch_size = self.train_params['batch_size']
        print_every = self.train_params['print_every']

        optimizer  = torch.optim.Adam(model.parameters())
        dataloader_train = DataLoader(self.dataset_train, batch_size = batch_size, shuffle = True)
        model.to(device=self.device)
        model.train()
        for e in range(num_epochs):
            for t, (samples, _) in enumerate(dataloader_train):
                samples = samples.to(device=self.device, dtype=torch.float32)
                predictions, masked_latents = self.forward(model, samples)
                loss = self.loss(predictions, masked_latents)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if t % print_every == 0:
                    # writer.add_scalar("Loss/train", loss.item(), e*len(dataloader)+t)
                    print('Epoch %d, Iteration %d, loss = %.4f' % (e, t, loss.item()))

                metrics = {"train/train_loss": loss.item()}

                del samples
                del predictions
                del masked_latents
                del loss

            eval_train_score, eval_test_score = self.finetune_eval_score(model)

    def finetune_eval_score(self, model):
        model.eval()
        generator = torch.Generator().manual_seed(42)
        val_train, val_test = random_split(self.dataset_val, [0.7, 0.3], generator=generator)
        val_train_dataloader = DataLoader(val_train, batch_size = len(val_train), shuffle = True)
        val_test_dataloader = DataLoader(val_test, batch_size = len(val_test), shuffle = True)

        samples, labels = next(iter(val_train_dataloader))
        samples = samples.to(device=self.device, dtype=torch.float32)
        predictions = model(samples)
        # print(predictions)
        embeddings = torch.mean(predictions, dim=-1) # TODO is averaging the best strategy here, for classification?
        # print(embeddings)
        clf = LinearDiscriminantAnalysis()
        clf.fit(embeddings.detach().cpu().numpy(), labels.detach().cpu().numpy())
        train_score = clf.score(embeddings.detach().cpu().numpy(), labels.detach().cpu().numpy())
        print('Eval train score:', train_score)

        samples_test, labels_test = next(iter(val_test_dataloader))
        samples_test = samples_test.to(device=self.device, dtype=torch.float32)
        predictions = model(samples_test)
        embeddings = torch.mean(predictions, dim=-1) # TODO is averaging the best strategy here, for classification?
        test_score = clf.score(embeddings.detach().cpu().numpy(), labels_test.detach().cpu().numpy())
        print('Eval test score:', test_score)
        return train_score, test_score

In [None]:
import torch
from torch.nn import functional as F
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
from joblib import Parallel, delayed
import os
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

class RelativePositioningTask():
    def __init__(self,
                dataset: torch.utils.data.Dataset,
                win_length = 50,
                tau_pos = 150,
                tau_neg = 170,
                n_samples = 1,
                task_params={
                    'mask_prob': 0.5
                },
                train_params={
                    'num_epochs': 100,
                    'batch_size': 10,
                    'print_every': 10
                },
                verbose=False
        ):
        self.dataset = dataset
        self.train_test_split()
        self.win = win_length
        self.tau_pos = tau_pos
        self.tau_neg = tau_neg
        self.n_samples = n_samples

        self.train_params = train_params
        self.mask_probability = task_params['mask_prob']
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.verbose=verbose

    def train_test_split(self):
        generator = torch.Generator().manual_seed(42)
        self.dataset_train, self.dataset_val = torch.utils.data.random_split(self.dataset, [0.7,0.3], generator=generator)

    def forward(self, model, x):
        samples = []
        labels = []

        for anchor_start in np.arange(0, x.shape[2]-self.win, self.win): # non-overlapping anchor window
            # Positive window start t_pos:
            #     - |t_pos - t_anchor| <= tau_pos
            #           <-> t_pos <= tau_pos + t_anchor
            #           <-> t_pos => t_anchor - tau_pos
            #     - t_pos < T - win
            #.    - t_pos > 0
            pos_winds_start = np.arange(np.maximum(0, anchor_start - self.tau_pos), np.minimum(anchor_start+self.tau_pos, x.shape[2]-self.win), self.win) # valid positive samples onsets
            if len(pos_winds_start) > 0:
                # positive context
                pos_winds = [x[:, :, sample_start:sample_start+self.win] for sample_start in np.random.choice(pos_winds_start, self.n_samples, replace=False)]
                anchors = [x[:, :,anchor_start:anchor_start+self.win] for i in range(len(pos_winds))] # repeat same anchor window

                anch = torch.stack([anchors[i].clone().detach() for i in range(len(anchors))])[0]
                pos_w = torch.stack([pos_winds[i].clone().detach() for i in range(len(anchors))])[0]

                samples.append(torch.stack([anch, pos_w])) # if anchors[i].shape == pos_winds[i].shape])
                labels.append(torch.ones(len(anchors)))

                # negative context
                # Negative window start t_neg:
                #     - |t_neg - t_anchor| > tau_neg
                #           <-> t_neg > tau_neg + t_anchor
                #           <-> t_neg < t_anchor - tau_neg
                #     - t_neg < T - win
                #.    - t_neg > 0
                neg_winds_start = np.concatenate((np.arange(0, anchor_start-self.tau_neg, self.win), np.arange(anchor_start+self.tau_neg, x.shape[2]-self.win, self.win)))
                neg_winds = [x[:, :,sample_start:sample_start+self.win] for sample_start in np.random.choice(neg_winds_start, self.n_samples, replace=False)]

                anch = torch.stack([anchors[i].clone().detach() for i in range(len(anchors))])[0]
                neg_w = torch.stack([neg_winds[i].clone().detach() for i in range(len(anchors))])[0]

                samples.append(torch.stack([anch, neg_w])) # if anchors[i].shape == neg_winds[i].shape])
                labels.append(torch.zeros(len(anchors)))

        samples = torch.stack(samples) # N x 2 (anchors, pos/neg) x C x W
        if len(samples) != len(labels):
            raise ValueError('Number of samples and labels mismatch')
        labels = torch.stack(labels)

        embeddings = []

        for i in range(samples.shape[0]):
          embeddings.append(model.feature_encoder(samples[i][:, 0]))

        predictions = []

        for i in range(len(embeddings)):
          predictions.append(model.classify(embeddings[i]))

        predictions = torch.stack(predictions)
        labels = labels.long()

        return predictions, labels

    def loss(self, predictions, labels):
        return F.cross_entropy(predictions, labels)

    def train(self, model, train_params={}):
        print('Training on ', self.device)
        self.train_params.update(train_params)
        num_epochs = self.train_params['num_epochs']
        batch_size = self.train_params['batch_size']
        print_every = self.train_params['print_every']

        optimizer  = torch.optim.Adam(model.parameters())
        dataloader_train = DataLoader(self.dataset_train, batch_size = batch_size, shuffle = True)
        model.to(device=self.device)
        model.train()
        for e in range(num_epochs):
            for t, (samples, _) in enumerate(dataloader_train):
                samples = samples.to(device=self.device, dtype=torch.float32)
                predictions, labels = self.forward(model, samples)
                loss = self.loss(predictions, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if t % print_every == 0:
                    # writer.add_scalar("Loss/train", loss.item(), e*len(dataloader)+t)
                    print('Epoch %d, Iteration %d, loss = %.4f' % (e, t, loss.item()))

                metrics = {"train/train_loss": loss.item()}

                del samples
                del predictions
                del labels
                del loss

            eval_train_score, eval_test_score = self.finetune_eval_score(model)

    def finetune_eval_score(self, model):
        model.eval()
        generator = torch.Generator().manual_seed(42)
        val_train, val_test = random_split(self.dataset_val, [0.7, 0.3], generator=generator)
        val_train_dataloader = DataLoader(val_train, batch_size = len(val_train), shuffle = True)
        val_test_dataloader = DataLoader(val_test, batch_size = len(val_test), shuffle = True)

        samples, labels = next(iter(val_train_dataloader))
        samples = samples.to(device=self.device, dtype=torch.float32)
        predictions = model(samples)
        # print(predictions)
        embeddings = torch.mean(predictions, dim=-1) # TODO is averaging the best strategy here, for classification?
        # print(embeddings)
        clf = LinearDiscriminantAnalysis()
        clf.fit(embeddings.detach().cpu().numpy(), labels.detach().cpu().numpy())
        train_score = clf.score(embeddings.detach().cpu().numpy(), labels.detach().cpu().numpy())
        print('Eval train score:', train_score)

        samples_test, labels_test = next(iter(val_test_dataloader))
        samples_test = samples_test.to(device=self.device, dtype=torch.float32)
        predictions = model(samples_test)
        embeddings = torch.mean(predictions, dim=-1) # TODO is averaging the best strategy here, for classification?
        test_score = clf.score(embeddings.detach().cpu().numpy(), labels_test.detach().cpu().numpy())
        print('Eval test score:', test_score)
        return train_score, test_score

In [None]:
import torch
from torch.nn import functional as F
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
from joblib import Parallel, delayed
import os
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

class TemporalShufflingTask():
    def __init__(self,
                dataset: torch.utils.data.Dataset,
                win_length = 50,
                tau_pos = 150,
                tau_neg = 151,
                n_samples = 1,
                stride = 1,
                task_params={
                    'mask_prob': 0.5
                },
                train_params={
                    'num_epochs': 100,
                    'batch_size': 10,
                    'print_every': 10
                },
                verbose=False
        ):
        self.dataset = dataset
        self.train_test_split()
        self.win = win_length
        self.tau_pos = tau_pos
        self.tau_neg = tau_neg
        self.n_samples = n_samples
        self.stride = stride

        self.train_params = train_params
        self.mask_probability = task_params['mask_prob']
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.verbose=verbose

    def train_test_split(self):
        generator = torch.Generator().manual_seed(42)
        self.dataset_train, self.dataset_val = torch.utils.data.random_split(self.dataset, [0.7,0.3], generator=generator)

    def forward(self, model, x):
        samples = []
        labels = []

        tau_pos = self.tau_pos
        for pos_start in np.arange(0, x.shape[2], tau_pos): # non-overlapping positive contexts
            if pos_start + tau_pos < x.shape[2]:
                pos_winds = [x[:, :, pos_start:pos_start+self.win], x[:, :, pos_start+self.win*2:pos_start+self.win*3]] # two positive windows\
                inorder = torch.stack(pos_winds[:1] + [x[:, :, pos_start+self.win:pos_start+self.win*2]] + pos_winds[1:])
                samples.extend([inorder, torch.flip(inorder, dims = [0])])
                labels.extend(torch.ones(2))

                # for negative windows, want both sides of anchor window
                neg_winds_start = np.concatenate((np.arange(0, pos_start-self.tau_neg-self.win, self.stride), np.arange(pos_start+tau_pos+self.tau_neg, x.shape[2]-self.win, self.stride)))
                selected_neg_start = np.random.choice(neg_winds_start, 1, replace=False)[0]
                disorder = torch.stack(pos_winds[:1] + [x[:,:,selected_neg_start:selected_neg_start+self.win]] + pos_winds[1:]) # two positive windows, disorder sample added to the end
                samples.extend([disorder, torch.flip(disorder, dims = [0])])
                labels.extend(torch.zeros(2))

        samples = torch.stack(samples)
        labels = torch.stack(labels).unsqueeze(1)
        if len(samples) != len(labels):
            raise ValueError('Number of samples and labels mismatch')

        embeddings = []

        for i in range(samples.shape[0]):
          embeddings.append(model.feature_encoder(samples[i][:, 0]))

        predictions = []

        for i in range(len(embeddings)):
          predictions.append(model.classify(embeddings[i]))

        predictions = torch.stack(predictions)
        labels = labels.long()

        return predictions, labels

    def loss(self, predictions, labels):
        return F.cross_entropy(predictions, labels)

    def train(self, model, train_params={}):
        print('Training on ', self.device)
        self.train_params.update(train_params)
        num_epochs = self.train_params['num_epochs']
        batch_size = self.train_params['batch_size']
        print_every = self.train_params['print_every']

        optimizer  = torch.optim.Adam(model.parameters())
        dataloader_train = DataLoader(self.dataset_train, batch_size = batch_size, shuffle = True)
        model.to(device=self.device)
        model.train()
        for e in range(num_epochs):
            for t, (samples, _) in enumerate(dataloader_train):
                samples = samples.to(device=self.device, dtype=torch.float32)
                predictions, labels = self.forward(model, samples)
                loss = self.loss(predictions, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if t % print_every == 0:
                    # writer.add_scalar("Loss/train", loss.item(), e*len(dataloader)+t)
                    print('Epoch %d, Iteration %d, loss = %.4f' % (e, t, loss.item()))

                metrics = {"train/train_loss": loss.item()}

                del samples
                del predictions
                del labels
                del loss

            eval_train_score, eval_test_score = self.finetune_eval_score(model)

    def finetune_eval_score(self, model):
        model.eval()
        generator = torch.Generator().manual_seed(42)
        val_train, val_test = random_split(self.dataset_val, [0.7, 0.3], generator=generator)
        val_train_dataloader = DataLoader(val_train, batch_size = len(val_train), shuffle = True)
        val_test_dataloader = DataLoader(val_test, batch_size = len(val_test), shuffle = True)

        samples, labels = next(iter(val_train_dataloader))
        samples = samples.to(device=self.device, dtype=torch.float32)
        predictions = model(samples)
        # print(predictions)
        embeddings = torch.mean(predictions, dim=-1) # TODO is averaging the best strategy here, for classification?
        # print(embeddings)
        clf = LinearDiscriminantAnalysis()
        clf.fit(embeddings.detach().cpu().numpy(), labels.detach().cpu().numpy())
        train_score = clf.score(embeddings.detach().cpu().numpy(), labels.detach().cpu().numpy())
        print('Eval train score:', train_score)

        samples_test, labels_test = next(iter(val_test_dataloader))
        samples_test = samples_test.to(device=self.device, dtype=torch.float32)
        predictions = model(samples_test)
        embeddings = torch.mean(predictions, dim=-1) # TODO is averaging the best strategy here, for classification?
        test_score = clf.score(embeddings.detach().cpu().numpy(), labels_test.detach().cpu().numpy())
        print('Eval test score:', test_score)
        return train_score, test_score

In [None]:
class RandomDataset(Dataset):
  def __init__(self):
    self.x = 10
  def __len__(self):
    return 100
  def __getitem__(self, idx):
    return torch.randn(128, 512), 1

In [None]:
data = RandomDataset()
task = TemporalShufflingTask(data)
model = FeatureEncoder([10, 3, 3, 3, 3, 2, 2], [2, 1, 1, 1, 1, 1, 1], 3)
task.train(model, {'num_epochs': 10, 'batch_size': 1, 'print_every': 1})

Training on  cpu
Epoch 0, Iteration 0, loss = 3.4655
Epoch 0, Iteration 1, loss = 3.4324
Epoch 0, Iteration 2, loss = 3.2957
Epoch 0, Iteration 3, loss = 3.1918
Epoch 0, Iteration 4, loss = 3.1059
Epoch 0, Iteration 5, loss = 3.0360
Epoch 0, Iteration 6, loss = 2.9655
Epoch 0, Iteration 7, loss = 2.8876


KeyboardInterrupt: 

In [None]:
data = RandomDataset()
task = MaskedContrastiveLearningTask(data)
model = FeatureEncoder([10, 3, 3, 3, 3, 2, 2], [2, 1, 1, 1, 1, 1, 1])
task.train(model, {'num_epochs': 10, 'batch_size': 1, 'print_every': 1})

Training on  cpu
torch.Size([1, 128, 512])


RuntimeError: Given groups=1, weight of size [512, 128, 10], expected input[1, 1, 512] to have 128 channels, but got 1 channels instead

In [None]:
class TransformerLayer(nn.Module):
  def __init__(self):
    super(TransformerLayer, self).__init__()
    self.pos_emb = PosEmb()
    self.norm = nn.Sequential(
        nn.LayerNorm((768,), eps=1e-05, elementwise_affine=True),
        nn.Dropout(p=0.1, inplace=False)
    )
    self.attention = []
    for i in range(12):
      self.attention.append(nn.MultiheadAttention(768, 1))
    self.feed_forward = nn.Sequential(
        nn.Dropout(p=0.1, inplace=False),
        nn.Linear(in_features=768, out_features=3072, bias=True),
        nn.GELU(),
        nn.Linear(in_features=3072, out_features=768, bias=True),
        nn.Dropout(p=0.1, inplace=False),
        nn.LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
  def forward(self, x):
    x = self.pos_emb(x) + torch.permute(x, (0,2,1))
    x = self.norm(x)
    for i in range(12):
      x = self.attention[i](x, x, x)[0]
    return torch.permute(self.feed_forward(x), (0,2,1))

In [None]:
model_conv = FeatureEncoder([10, 3, 3, 3, 3, 2, 2], [2, 1, 1, 1, 1, 1, 1])
print(model_conv.receptive_field())
model_encoder = TransformerLayer()
print(model_encoder)

(30, 2)




TransformerLayer(
  (pos_emb): PosEmb(
    (conv): Conv1d(768, 768, kernel_size=(128,), stride=(1,), padding=(64,), groups=16)
    (activation): GELU(approximate='none')
  )
  (norm): Sequential(
    (0): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (1): Dropout(p=0.1, inplace=False)
  )
  (feed_forward): Sequential(
    (0): Dropout(p=0.1, inplace=False)
    (1): Linear(in_features=768, out_features=3072, bias=True)
    (2): GELU(approximate='none')
    (3): Linear(in_features=3072, out_features=768, bias=True)
    (4): Dropout(p=0.1, inplace=False)
    (5): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
)


In [None]:
def contrastive_loss(y, z):
  mat = torch.einsum('bdj,bdk->bjk', y, z)
  N = y.shape[2]
  target = torch.arange(N)
  loss = 0
  for i in range(y.shape[0]):
    loss += torch.nn.functional.cross_entropy(mat[0], target)
  return loss/y.shape[0]

In [None]:
import random
x = torch.randn(8, 128, 512)
y = model_conv(x)
print(y.shape)
z = model_encoder(torch.randn(8, 768, 720))
print(z.shape)
# y1 = y
# rlist = []
# for i in range(16):
#   rlist.append(random.randint(1, 112))
# for r in rlist:
#   y1[:, :, r] = 0
# z = model_encoder(y1)
# z1 = z[:, :, rlist]
# y2 = y[:, :, rlist]
# print(contrastive_loss(y2, z1))

torch.Size([8, 768, 242])
torch.Size([8, 768, 720])


In [None]:
model_encoder(y).shape

torch.Size([8, 768, 114])

In [None]:
import torch.nn as nn
import torch

class KrakenEncoder(nn.Module):
  def __init__(self):
    super(KrakenEncoder, self).__init__()
    self.conv0 = nn.Sequential(
        nn.Conv2d(1, 64, kernel_size = (4, 2), stride = (4, 2)),
        nn.LeakyReLU(),
        nn.GroupNorm(1,64)
    )
    self.conv1 = nn.Sequential(
        nn.Conv2d(64, 128, kernel_size = (4, 2)),
        nn.LeakyReLU(),
        nn.GroupNorm(1,128)
    )
    self.conv2 = nn.Sequential(
        nn.MaxPool2d((4, 2), stride = (1, 2)),
        nn.Conv2d(128, 256, kernel_size = (3, 3)),
        nn.MaxPool2d((4, 2), stride = (1, 2))
    )
    self.bilstm = nn.LSTM(10*256, 256, 3, bidirectional= True, batch_first=True)
  def forward(self, x):
    x = self.conv0(x)
    x = self.conv1(x)
    x = self.conv2(x)
    print(x.shape)
    x = torch.transpose(x,2,3)
    x = torch.flatten(x, start_dim = 1, end_dim = 2)
    x = torch.transpose(x,1,2)
    print('input dim before LSTM', x.shape)
    x = self.bilstm(x)
    return torch.transpose(x[0],1,2)

In [None]:
encoder = KrakenEncoder()
x = torch.randn(8, 1, 1024, 96) # why 1?
print(encoder(x).shape) # 256 is sequence length which is wrong. It's the channel/output dim

torch.Size([8, 256, 245, 10])
input dim before LSTM torch.Size([8, 245, 2560])
torch.Size([8, 512, 245])


In [None]:
encoder = KrakenEncoder()
pytorch_total_params = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
print(pytorch_total_params)

9286976
