In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import torch.optim as optim
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module

import pickle
import gzip

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
class GraphConvolution(Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """

    def __init__(self, in_features, out_features, k_cluster, t, v, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.k_cluster = k_cluster
        self.t = t
        self.v = v
        self.weight = nn.Parameter(torch.rand(k_cluster, in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()


    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj, c):
        b = input.size()[0] 
        # c*w  , c.shape = (b,n,k)  
        output = torch.mm(c.view(-1,self.k_cluster) , self.weight.view(self.k_cluster, -1))  # c * w  , out : (b, n,f1*f2)
        output = output.view(-1, self.v, self.in_features, self.out_features)  #b , n, f1, f2

        # xcw  , (b, n ,t , f1)  * (b,  n , f1 , f2)   , output: (b, n,t , f2)  
        input = input.contiguous().view(-1, self.t , self.in_features)
        output = output.contiguous().view(-1,self.in_features , self.out_features)
        
        output = torch.bmm(input, output).view(-1,self.v , self.t , self.out_features )

        # gcn , (b,n, n) * ( b, n , t*f2)
        graph = adj
        graph.unsqueeze(-1)
        graph = graph.repeat(b ,1,1)  # shape : (b , n , n) 
        output = torch.bmm(graph, output.view( b, self.v, self.t * self.out_features))
        output = output.view(-1, self.v , self.t , self.out_features)

        if self.bias is not None:
            return output + self.bias
        else:
            return output


    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'

    def compute_output_shape(self, input_shapes):
        batch_size = input_shapes[0]
        graph_size = input_shapes[1]
        return (batch_size, graph_size, self.units)

In [5]:
class ClusterBlock(nn.Module):
    def __init__(self, n_input, n_output, v ,T, k_cluster):
        super(ClusterBlock, self).__init__( )
        self.n_input = n_input
        self.n_output = n_output
        self.v = v  # number of node
        self.T = T
        self.k_cluster = k_cluster
        self.linear1 = nn.Linear(n_input , 1)  # Dense
        self.linear2 = nn.Linear(T ,k_cluster ) # clustering
        self.softmax = nn.Softmax(dim =-1)
        self.gc1 = GraphConvolution(n_input, n_output,k_cluster,self.T , self.v)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x, graph):  # x.shape = (b,v,T,f)
        # squeeze f
        out = self.linear1(x)    #  out.shape = (b , v ,T,1)
        out = out.view(-1  , self.v , self.T)  #  out.shape = (b , v ,T)

        # clustering
        out =self.linear2(out)  
        out = self.softmax(out)   # out.shape = (b,v,k)
        
        #gcn  LXcw
        c = out
        out  = F.relu(self.gc1(x,graph,out))
        out = self.dropout(out)
        return out , c 

In [6]:
class ScaledDotProductAttention(nn.Module):
    ''' Scaled Dot-Product Attention '''

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, k, v, mask=None):

        attn = torch.bmm(q, k.transpose(1, 2))
        attn = attn / self.temperature

        if mask is not None:
            attn = attn.masked_fill(mask, -np.inf)  # fill the place with -inf where mask tensor is 1.

        attn = self.softmax(attn)
        attn = self.dropout(attn)
        output = torch.bmm(attn, v)

        return output, attn

In [7]:
''' Define the sublayers in encoder/decoder layer '''
class MultiHeadAttention(nn.Module):
    ''' Multi-Head Attention module '''

    def __init__(self, n_head, d_model_in, d_model_out, d_k, d_v, dropout=0.1):
        super().__init__()
        assert d_model_out == n_head * d_k
        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v
            
        self.w_qs = nn.Linear(d_model_out, n_head * d_k)
        self.w_ks = nn.Linear(d_model_in, n_head * d_k)
        self.w_vs = nn.Linear(d_model_in, n_head * d_v)
        nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model_in + d_k)))
        nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model_in + d_k)))
        nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model_in + d_v)))

        self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5))
        self.layer_norm = nn.LayerNorm(d_model_out)

        self.fc = nn.Linear(n_head * d_v, d_model_out)
        nn.init.xavier_normal_(self.fc.weight)
        
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):

        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head

        sz_b, len_q, _ = q.size()
        sz_b, len_k, _ = k.size()
        sz_b, len_v, _ = v.size()

        residual = q
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
        k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
        v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv
        
        if mask is not None:
            mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x ..
        output, attn = self.attention(q, k, v, mask=mask)

        output = output.view(n_head, sz_b, len_q, d_v)
        output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv)

        output = self.dropout(self.fc(output))
        output = self.layer_norm(output + residual)

        return output, attn


class PositionwiseFeedForward(nn.Module):
    ''' A two-feed-forward-layer module '''

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise
        self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise
        self.layer_norm = nn.LayerNorm(d_in)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        residual = x
        output = x.transpose(1, 2)
        output = self.w_2(F.relu(self.w_1(output)))
        output = output.transpose(1, 2)
        output = self.dropout(output)
        output = self.layer_norm(output + residual)
        return output



In [8]:
''' Define the Layers '''
class EncoderLayer(nn.Module):
    ''' Compose with two layers '''

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input_qk, enc_input_v, slf_attn_mask=None):
        enc_output, enc_slf_attn = self.slf_attn(enc_input_qk, enc_input_qk, enc_input_v, mask=slf_attn_mask)
        enc_output = self.pos_ffn(enc_output)
        return enc_output, enc_slf_attn


class DecoderLayer(nn.Module):
    ''' Compose with three layers '''

    def __init__(self, d_model_in, d_model_out, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model_out, d_model_out, d_k, d_v, dropout=dropout)
        self.enc_attn = MultiHeadAttention(n_head, d_model_in, d_model_out, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model_out, d_inner, dropout=dropout)

    def forward(self, dec_input_qk, dec_input_v, enc_output, slf_attn_mask=None):
        dec_output, dec_slf_attn = self.slf_attn(dec_input_qk, dec_input_qk, dec_input_v, mask=slf_attn_mask)
        dec_output, dec_enc_attn = self.enc_attn(dec_output, enc_output, enc_output, mask=None)
        dec_output = self.pos_ffn(dec_output)
        return dec_output, dec_slf_attn, dec_enc_attn




In [9]:
'''A wrapper class for optimizer '''
import numpy as np

class ScheduledOptim():
    '''A simple wrapper class for learning rate scheduling'''

    def __init__(self, optimizer, d_model, n_warmup_steps):
        self._optimizer = optimizer
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = 0
        self.init_lr = np.power(d_model, -0.5)

    def step_and_update_lr(self):
        "Step with the inner optimizer"
        self._update_learning_rate()
        self._optimizer.step()

    def zero_grad(self):
        "Zero out the gradients by the inner optimizer"
        self._optimizer.zero_grad()

    def _get_lr_scale(self):
        return np.min([
            np.power(self.n_current_steps, -0.5),
            np.power(self.n_warmup_steps, -1.5) * self.n_current_steps])

    def _update_learning_rate(self):
        ''' Learning rate scheduling per step '''

        self.n_current_steps += 1
        lr = self.init_lr * self._get_lr_scale()

        for param_group in self._optimizer.param_groups:
            param_group['lr'] = lr




In [10]:
''' Define the Transformer model '''

def get_subsequent_mask(seq):
    ''' For masking out the subsequent info. '''

    sz_b, len_s = seq.size()[:-1]
    subsequent_mask = torch.triu(
        torch.ones((len_s, len_s), device=seq.device, dtype=torch.uint8), diagonal=1)
    subsequent_mask = subsequent_mask.unsqueeze(0).expand(sz_b, -1, -1)  # b x ls x ls

    return subsequent_mask

def cal_weekly(n_position):
    return np.array([(2 * np.pi / 7) * (i / 48) for i in range(n_position)])

def cal_daily(n_position):
    return np.array([2 * np.pi * (i / 48) for i in range(n_position)])

def get_qk_encoding_table(n_position, d_hid):

    def cal_closeness(t, f):
        closeness_vec = np.array([np.exp(-(i / 48) ** 2) for i in range(t)])
        return closeness_vec.reshape(-1, 1).repeat(f, axis=-1)

    def cal_compensation(t, f):
        compensation_vec = np.array([np.exp(-((i / 48) % 7 + 3.5) ** 2) for i in range(t)])
        return compensation_vec.reshape(-1, 1).repeat(f, axis=-1)

    weekly_table = cal_weekly(n_position).reshape(-1, 1).repeat(d_hid, axis=-1)
    daily_table = cal_daily(n_position).reshape(-1, 1).repeat(d_hid, axis=-1)
    weekly_table[:, 0::2] = np.sin(weekly_table[:, 0::2])  # dim 2i
    weekly_table[:, 1::2] = np.cos(weekly_table[:, 1::2])  # dim 2i+1
    daily_table[:, 0::2] = np.sin(daily_table[:, 0::2])  # dim 2i
    daily_table[:, 1::2] = np.cos(daily_table[:, 1::2])  # dim 2i+1

    return torch.FloatTensor(np.array([weekly_table, daily_table, 
                                       cal_closeness(n_position, d_hid), 
                                       cal_compensation(n_position, d_hid)]).transpose(1,2,0).reshape(-1, 4))


def get_v_encoding_table(n_position, d_hid):

    weekly_table = cal_weekly(n_position).reshape(-1, 1).repeat(d_hid, axis=-1)
    daily_table = cal_daily(n_position).reshape(-1, 1).repeat(d_hid, axis=-1)
    weekly_table[:, 0::2] = np.sin(weekly_table[:, 0::2])  # dim 2i
    weekly_table[:, 1::2] = np.cos(weekly_table[:, 1::2])  # dim 2i+1
    daily_table[:, 0::2] = np.sin(daily_table[:, 0::2])  # dim 2i
    daily_table[:, 1::2] = np.cos(daily_table[:, 1::2])  # dim 2i+1

    closeness_vec = np.array([-2 * (i / 48) for i in range(n_position)])
    closeness_table = closeness_vec.reshape(-1, 1).repeat(d_hid, axis=-1)

    compensation_vec = np.array([-2 * ((i / 48) % 7 + 3.5) for i in range(n_position)])
    compensation_table = compensation_vec.reshape(-1, 1).repeat(d_hid, axis=-1)

    # if padding_idx is not None:
    #     # zero vector for padding dimension
    #     sinusoid_table[padding_idx] = 0.
    return torch.FloatTensor(np.array([weekly_table, daily_table, 
                                       closeness_table, compensation_table]).transpose(1,2,0).reshape(-1, 4))


class Encoder(nn.Module):
    ''' A encoder model with self attention mechanism. '''

    def __init__(self, n_input, n_output, v, T, k_cluster, 
                 n_layers, n_head, d_k, d_v, d_model, d_inner, dropout=0.1):

        super().__init__()
        self.cluster_block = ClusterBlock(n_input, n_output, v, T, k_cluster)
        self.n_output = n_output
        self.position_enc = nn.Linear(4, 1, bias=False)
        self.d_model = d_model
        self.layer_stack = nn.ModuleList([
            EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])

    def forward(self, src_seq, src_pos, graph, return_attns=False):
        sz_b, v, t, f = src_seq.size()
        enc_slf_attn_list = []
        cluster_mat = []
        # -- Forward
        src_seq, cluster = self.cluster_block(src_seq, graph)
        cluster_mat += [cluster]
        src_seq = src_seq.permute(0, 2, 1, 3).contiguous().view(sz_b, t, -1)  # (b, t, v*f2)
        pe_qk = self.position_enc(get_qk_encoding_table(src_pos, self.d_model).to(src_seq.device)).view(src_pos, -1)
        pe_v = self.position_enc(get_v_encoding_table(src_pos, self.d_model).to(src_seq.device)).view(src_pos, -1)
        enc_output_qk = src_seq + pe_qk
        enc_output_v = src_seq + pe_v
        enc_output, enc_slf_attn = self.layer_stack[0](enc_output_qk, enc_output_v)
        
        if return_attns:
            enc_slf_attn_list += [enc_slf_attn]

        for enc_layer in self.layer_stack[1:]:
            enc_output = enc_output.view(sz_b, t, v, self.n_output).permute(0, 2, 1, 3)  # (b, v, t, f2)
            enc_output, cluster = self.cluster_block(enc_output, graph)
            cluster_mat += [cluster]
            enc_output = enc_output.permute(0, 2, 1, 3).contiguous().view(sz_b, t, -1)
            enc_output, enc_slf_attn = enc_layer(enc_output, enc_output)

            if return_attns:
                enc_slf_attn_list += [enc_slf_attn]

        if return_attns:
            return enc_output, enc_slf_attn_list, cluser_mat
        return enc_output, cluster_mat


class Decoder(nn.Module):
    ''' A decoder model with self attention mechanism. '''

    def __init__(self, n_layers, n_head, d_k, d_v, d_model_in, d_model_out, d_inner, dropout=0.1):

        super().__init__()
        self.position_enc = nn.Linear(4, 1, bias=False)
        self.d_model_in = d_model_in
        self.d_model_out = d_model_out

        self.layer_stack = nn.ModuleList([
            DecoderLayer(d_model_in, d_model_out, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)])

    def forward(self, tgt_seq, tgt_pos, enc_output, return_attns=False):

        dec_slf_attn_list, dec_enc_attn_list = [], []
        slf_attn_mask_subseq = get_subsequent_mask(tgt_seq)
#         _, t, f = tgt_seq.shape
        # -- Forward
        pe_qk = self.position_enc(get_qk_encoding_table(tgt_pos, self.d_model_out).to(tgt_seq.device)).view(tgt_pos, -1)
        pe_v = self.position_enc(get_v_encoding_table(tgt_pos, self.d_model_out).to(tgt_seq.device)).view(tgt_pos, -1)
        dec_output_qk = tgt_seq + pe_qk
        dec_output_v = tgt_seq + pe_v
        dec_output, dec_slf_attn, dec_enc_attn = self.layer_stack[0](dec_output_qk, dec_output_v, 
                                                                     enc_output, slf_attn_mask_subseq)
        if return_attns:
                dec_slf_attn_list += [dec_slf_attn]
                dec_enc_attn_list += [dec_enc_attn]
                
        for dec_layer in self.layer_stack[1:]:
            dec_output, dec_slf_attn, dec_enc_attn = dec_layer(dec_output, dec_output, enc_output, slf_attn_mask_subseq)
            if return_attns:
                dec_slf_attn_list += [dec_slf_attn]
                dec_enc_attn_list += [dec_enc_attn]

        if return_attns:
            return dec_output, dec_slf_attn_list, dec_enc_attn_list
        return dec_output,


class Transformer(nn.Module):
    ''' A sequence to sequence model with attention mechanism. '''

    def __init__(self, n_input, n_output, v, T, k_cluster,
                 d_model_enc, d_model_dec, d_inner, n_layers,
                 n_head_enc, n_head_dec, d_k_enc, d_v_enc, d_k_dec, d_v_dec, dropout=0.1):
        '''
        :param d_model_enc: dimension of encoder, always the last dimension of src sequence.
        :param d_model_dec: dimension of decoder, always the last dimension of tgt sequence.
        :param d_inner: the latent dimension of FF sublayer.
        :param n_layers: the number of encoder or decoder layers.
        :param n_head_enc: the number of encoder heads.
        :param n_head_dec: the number of decoder heads.
        :param d_k_enc: equal to d_model_enc / n_head_enc.
        :param d_v_enc: equal to d_model_enc / n_head_enc.
        :param d_k_dec: equal to d_model_dec / n_head_dec.
        :param d_v_dec: equal to d_model_dec / n_head_dec.
        :param dropout: dropout, default is None.
        '''

        super().__init__()
        self.encoder = Encoder(n_input=n_input, n_output=n_output, v=v, T=T, k_cluster=k_cluster,
                               d_model=d_model_enc, d_inner=d_inner,
                               n_layers=n_layers, n_head=n_head_enc, d_k=d_k_enc, d_v=d_v_enc,
                               dropout=dropout)

        self.decoder = Decoder(d_model_in=d_model_enc,d_model_out=d_model_dec, d_inner=d_inner,
                               n_layers=n_layers, n_head=n_head_dec, d_k=d_k_dec, d_v=d_v_dec,
                               dropout=dropout)

    def forward(self, src_seq, tgt_seq, graph):
        src_pos = src_seq.size()[-2]
        tgt_pos = tgt_seq.size()[-2]
        enc_output, cluster, *_ = self.encoder(src_seq, src_pos, graph)
        dec_output, *_ = self.decoder(tgt_seq, tgt_pos, enc_output)
        return dec_output.view(-1, tgt_pos, dec_output.size(2)), cluster

In [None]:
data = pickle.load(open(''))

In [10]:
data = pickle.load(open('/nfs/project/cache_rec/data/19data/bj19_demand.pkl', 'rb'), encoding='iso-8895-1')
data = np.array(data['data'])

In [11]:
euc_graph = pickle.load(open('/nfs/project/cache_rec/graph/bj_L_euc.pkl', 'rb'), encoding='iso-8895-1')
graph = torch.Tensor(euc_graph[0]).to(device)

In [12]:
data[0]

array([[0., 0., 0., ..., 6., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [1., 0., 0., ..., 4., 1., 2.],
       [0., 4., 0., ..., 8., 3., 1.],
       [1., 5., 1., ..., 0., 0., 1.]])

In [13]:
# Hyper-parameters

# GCN
b ,v = data.shape[0], data.shape[1] * data.shape[2]
n_block = 2
n_output = n_input = 1
k_cluster = 5

#Transformer
d_model_enc = v * n_output
d_model_dec = v 
d_inner = 200
n_layers = 3
n_head_enc = 9
n_head_dec = 9
d_k_enc = d_v_enc = 144
d_k_dec = d_v_dec = 144 

learning_rate = 0.001
num_epochs = 50

batch_size = 100
x_length = 2*48
y_length = 1

In [None]:
# model = ClusterBlock(n_input, n_output, v ,T, k_cluster).to(device)

In [15]:
# Build model.
model = Transformer(n_input = n_input, 
                n_output = n_output, 
                v = v, 
                T = x_length, 
                k_cluster = k_cluster, 
                d_model_enc = d_model_enc, 
                d_model_dec = d_model_dec, 
                d_inner = d_inner, 
                n_layers = n_layers, 
                n_head_enc = n_head_enc, 
                n_head_dec = n_head_dec, 
                d_k_enc = d_k_enc, 
                d_v_enc = d_v_enc,
                d_k_dec = d_k_dec, 
                d_v_dec = d_v_dec).to(device)

In [16]:
class linearNormalizer(object):
    def __init__(self):
        pass
    
    def fit(self,X):
        self._min = X.min()
        self._max = X.max()
#         print(('min:', self._min, 'max:', self._max))
        
    def transform(self,X,max_limit=500):
        X[X>max_limit]=max_limit
        X = (X-self._min)/(max_limit-self._min)
        return X
    
    def fit_transform(self,X):
        self.fit(X)
        return self.transform(X)

    def inverse_transform(self,X):
        X = X*(self._max-self._min)+self._min
        return X
    
    def real_loss(self,loss):
        X = X*(self._max-self._min)

In [17]:
class DataLoader():
    def __init__(self, data):
        self.start = 0
        self.data = np.array(data)

    def load_batch(self, lx, ly, batch_size):
        if self.start == 0:
            self.start += lx
        end = self.start + batch_size
        if end > len(self.data) - ly + 1:
            end = len(self.data) - ly + 1
        x = []; y = []
        for i in range(self.start, end):
            x.append(self.data[i - lx:i].reshape(lx, -1))
            y.append(self.data[i: i + ly].reshape(ly, -1))
#         print(np.array(x).shape, np.array(y).shape, self.start)
        x = np.array(x).swapaxes(1, 2)[:, :, :, np.newaxis]
        y = np.array(y)
        self.start += batch_size
        return x, y

In [18]:
train_loader = DataLoader(data[:int(b*0.8)]) 

In [19]:
batch_size = 100
step = (b - v- 1) // 2 + 1

In [20]:
# Loss Function and optimizer
import itertools

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

criterion = nn.SmoothL1Loss()

def rmse(yhat,y):
    return torch.sqrt(torch.mean((yhat-y)**2))

def matrix_distance(A,B):
    D = A - B
    return torch.mean(D**2) 
def cluster_distance(m_list):
    p = list(itertools.combinations(m_list,2))
    res = []
    for item in p:
        res_item = matrix_distance(item[0], item[1])
        res.append(res_item)
    res = torch.Tensor(res) 
    return torch.mean(res)

In [21]:
euc_graph = pickle.load(open('/nfs/project/cache_rec/graph/bj_L_euc.pkl', 'rb'), encoding='iso-8895-1')
graph = torch.Tensor(euc_graph[0]).to(device)

In [22]:
for epoch in range(num_epochs):  
    for i in range(step):
        # Load data and transfer to GPU
        inputs, labels = train_loader.load_batch(x_length, y_length, batch_size)
        inputs = torch.Tensor(inputs).to(device)
        labels = torch.Tensor(labels).to(device)
        print("inputs shape:", inputs.shape )
        print('labels shape:', labels.shape)
        #forward pass 
        outputs, cluster = model(inputs, labels , graph)
        print('output shape', outputs.shape)
        
        # data loss  & cluster loss 
        a = criterion(outputs, labels) 
        b = cluster_distance(cluster)
        loss = a +b 
        
        print("Epoch [{}/{}] , train_loss: {:.6f} ..".format(epoch+1, num_epochs, loss.item()))
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

inputs shape: torch.Size([100, 1296, 96, 1])
labels shape: torch.Size([100, 1, 1296])
output shape torch.Size([100, 1, 1296])
Epoch [1/50] , train_loss: 15.980808 ..
inputs shape: torch.Size([100, 1296, 96, 1])
labels shape: torch.Size([100, 1, 1296])


RuntimeError: CUDA out of memory. Tried to allocate 47.50 MiB (GPU 0; 8.00 GiB total capacity; 5.81 GiB already allocated; 21.86 MiB free; 125.59 MiB cached)