In [168]:
from lib.utils import *
from lib.models import *
from lib.ekyn import *
from lib.env import *
from lib.datasets import *
from scipy.signal import resample
from sklearn.preprocessing import RobustScaler

class Windowset(Dataset):
    """Map-style dataset that pairs each input window with its label.

    Args:
        X: indexable collection of inputs; ``X[i]`` is returned as the input.
        y: indexable collection of labels; its length defines the dataset length.
    """

    def __init__(self, X, y):
        self.X, self.y = X, y

    def __len__(self):
        # The dataset length is taken from the labels only.
        return len(self.y)

    def __getitem__(self, idx):
        window = self.X[idx]
        label = self.y[idx]
        return window, label
def evaluate_utime(dataloader, model, criterion):
    """Run ``model`` over ``dataloader`` without gradients and collect per-epoch results.

    Each batch is expected to be ``(Xi, yi)``. The assumed shapes, which match the
    loaders in this notebook, are:
    ``yi`` is ``(batch, 10, 3)`` one-hot/probability labels, and
    ``model(Xi)`` returns ``(batch, 3, 10)`` logits.
    Both are flattened to one row per epoch.

    The model is put in eval mode for the pass and restored to its previous
    train/eval mode afterwards.

    Returns:
        y_true: float tensor of true class indices, one per epoch.
        y_pred: float tensor of predicted class indices, one per epoch.
        y_logits: softmax probabilities, shape ``(n_epochs, n_classes)``
            (the name is kept for callers; these are not raw logits).
        Mean ``criterion`` value per batch, or NaN for an empty dataloader.
    """
    was_training = model.training
    model.eval()
    true_parts, pred_parts, prob_parts = [], [], []
    loss_total = 0.0
    try:
        with torch.no_grad():
            for Xi, yi in dataloader:
                # (batch, 10, 3) -> (batch*10, 3): one label row per epoch
                yi = yi.flatten(0, 1)
                # (batch, 3, 10) -> (batch*10, 3), same row order as yi
                logits = model(Xi).transpose(1, 2).flatten(0, 1)
                loss_total += criterion(logits, yi).item()

                probs = torch.softmax(logits, dim=1).detach().cpu()
                true_parts.append(yi.argmax(dim=1).flatten().cpu())
                prob_parts.append(probs)
                pred_parts.append(probs.argmax(dim=1))
    finally:
        model.train(was_training)

    if not true_parts:
        return torch.Tensor(), torch.Tensor(), torch.Tensor(), float('nan')
    # Collect per batch and concatenate once (linear instead of quadratic).
    # Class indices are returned as float, matching the previous torch.cat-onto-
    # an-empty-float-Tensor accumulation.
    y_true = torch.cat(true_parts).float()
    y_pred = torch.cat(pred_parts).float()
    y_logits = torch.cat(prob_parts)
    return y_true, y_pred, y_logits, loss_total / len(dataloader)
def load_dataloader(id='A1-1', condition='Vehicle', shuffle=True, batch_size=32, old_fs=500):
    """Load one recording, resample it to 100 Hz, robust-scale it and wrap it in a DataLoader.

    Args:
        id: subject id passed to ``load_eeg_label_pair``.
        condition: condition passed to ``load_eeg_label_pair``.
        shuffle: whether the DataLoader shuffles windows.
        batch_size: DataLoader batch size.
        old_fs: sampling rate of the raw recording in Hz. The default of 500 matches
            ``load_eeg_label_pairs_resampled_scaled``.

    Returns:
        DataLoader over ``Windowset(X, y)``. Each ``X`` row holds 10,000 samples
        (ten 1,000-sample segments) and each ``y`` row has shape ``(10, 3)``.
    """
    X, y = load_eeg_label_pair(id=id, condition=condition)
    fs = 100
    X = X.flatten()
    # Derive the target length from the recording's own duration instead of
    # assuming it is exactly 86400 s long.
    n_samples = X.shape[0] // old_fs * fs
    X = torch.from_numpy(resample(X, n_samples)).reshape(-1, fs * 10)
    # Robust-scale all samples of the recording jointly, as a single feature.
    scaler = RobustScaler()
    X = torch.from_numpy(scaler.fit_transform(X.reshape(-1, 1)).reshape(-1, fs * 10)).float()
    # Group 10 consecutive fs*10-sample segments into one 10,000-sample model input.
    X = X.reshape(-1, 10000)
    y = y.reshape(-1, 10, 3)
    return DataLoader(Windowset(X, y), batch_size=batch_size, shuffle=shuffle)
class Encoder(nn.Module):
    """Two ``Conv1d -> LayerNorm -> ReLU`` stages, optionally followed by max pooling.

    Args:
        in_channels: channels of the input sequence.
        out_channels: channels produced by both convolutions.
        n_features: sequence length. LayerNorm normalises over
            ``(out_channels, n_features)``, so inputs must have exactly this length
            (the stride-1 'same'-padded convolutions preserve it).
        max_pool: if True, ``forward`` also returns a max-pooled copy of the output.
        max_pool_kernel_size: pooling kernel size; only meaningful when ``max_pool`` is True.
    """
    def __init__(self, in_channels, out_channels, n_features, max_pool, max_pool_kernel_size) -> None:
        super().__init__()
        self.max_pool = max_pool
        # Construction order is kept so parameter initialisation and state_dict keys are unchanged.
        self.c1 = nn.Conv1d(in_channels=in_channels,out_channels=out_channels,kernel_size=5,stride=1,dilation=2,padding='same')
        self.ln1 = nn.LayerNorm((out_channels,n_features))
        self.r1 = nn.ReLU()
        self.c2 = nn.Conv1d(in_channels=out_channels,out_channels=out_channels,kernel_size=5,stride=1,dilation=2,padding='same')
        self.ln2 = nn.LayerNorm(normalized_shape=(out_channels,n_features))
        self.r2 = nn.ReLU()
        self.mp1 = nn.MaxPool1d(kernel_size=max_pool_kernel_size)

    def forward(self, x):
        """Apply both conv stages.

        Returns:
            ``(skip, pooled)`` when ``max_pool`` is True. ``skip`` is the block output
            (consumed by a decoder's skip connection) and ``pooled = mp1(skip)``.
            Otherwise returns just the block output.
        """
        x = self.r1(self.ln1(self.c1(x)))
        # The second stage must consume c2's output; both the pooled path and
        # the skip connection carry the normalised, activated result.
        x = self.r2(self.ln2(self.c2(x)))
        if self.max_pool:
            return x, self.mp1(x)
        return x
class Decoder(nn.Module):
    """Upsample, convolve, concatenate a skip connection, then two more conv stages.

    Args:
        scale_factor: nearest-neighbour upsampling factor along the sequence axis.
        in_channels: channels of the input ``x``.
        out_channels: channels produced by every convolution. The skip tensor must also
            have this many channels, because ``c2`` expects ``2 * out_channels`` inputs.
        n_features: sequence length of ``x`` before upsampling. LayerNorm normalises over
            ``(out_channels, n_features * scale_factor)``.
        kernel_size: kernel size of all three convolutions (dilation 2, 'same' padding).
    """
    def __init__(self, scale_factor, in_channels, out_channels, n_features, kernel_size) -> None:
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=scale_factor,mode='nearest')
        self.c1 = nn.Conv1d(in_channels=in_channels,out_channels=out_channels,kernel_size=kernel_size,stride=1,dilation=2,padding='same')
        self.ln1 = nn.LayerNorm((out_channels,n_features * scale_factor))
        self.r1 = nn.ReLU()
        # Input is cat([c1 output, skip]) along channels.
        self.c2 = nn.Conv1d(in_channels=out_channels * 2,out_channels=out_channels,kernel_size=kernel_size,stride=1,dilation=2,padding='same')
        self.ln2 = nn.LayerNorm((out_channels,n_features * scale_factor))
        self.r2 = nn.ReLU()
        # c3 consumes c2's output, which has out_channels channels (not in_channels).
        self.c3 = nn.Conv1d(in_channels=out_channels,out_channels=out_channels,kernel_size=kernel_size,stride=1,dilation=2,padding='same')
        self.ln3 = nn.LayerNorm((out_channels,n_features * scale_factor))
        self.r3 = nn.ReLU()

    def forward(self, x, residual):
        """Decode ``x`` with the skip tensor ``residual``.

        ``x`` is ``(batch, in_channels, n_features)``. ``residual`` must be
        ``(batch, out_channels, n_features * scale_factor)``.
        Returns ``(batch, out_channels, n_features * scale_factor)``.
        """
        x = self.upsample(x)
        x = self.r1(self.ln1(self.c1(x)))
        x = torch.cat([x, residual], dim=1)
        x = self.r2(self.ln2(self.c2(x)))
        x = self.r3(self.ln3(self.c3(x)))
        return x
class UTIME(nn.Module):
    """Encoder/decoder network with skip connections producing per-epoch class logits.

    Input: any tensor reshapeable to ``(batch, 1, 10000)``.
    Output: logits of shape ``(batch, 3, 10)``, i.e. 3 classes for each of 10
    consecutive 1,000-sample segments.
    """
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Contracting path: 10000 -> 1000 (pool 10) -> 125 (pool 8) samples, 8 channels.
        self.encoder1 = Encoder(in_channels=1, out_channels=8, n_features=10000, max_pool=True, max_pool_kernel_size=10)
        self.encoder2 = Encoder(in_channels=8, out_channels=8, n_features=1000, max_pool=True, max_pool_kernel_size=8)
        self.encoder3 = Encoder(in_channels=8, out_channels=8, n_features=125, max_pool=False, max_pool_kernel_size=None)

        # Expanding path: 125 -> 1000 -> 10000 samples, each stage fed a skip connection.
        self.decoder1 = Decoder(scale_factor=8, in_channels=8, out_channels=8, n_features=125, kernel_size=8)
        self.decoder2 = Decoder(scale_factor=10, in_channels=8, out_channels=8, n_features=1000, kernel_size=10)

        # Head: 1x1 conv to 3 classes, average each 1000-sample segment, 1x1 conv.
        self.c = nn.Conv1d(in_channels=8, out_channels=3, kernel_size=1, stride=1)
        self.ap = nn.AvgPool1d(kernel_size=1000)
        self.c1 = nn.Conv1d(in_channels=3, out_channels=3, kernel_size=1, stride=1)

    def forward(self, x, features=False):
        """Return ``(batch, 3, 10)`` logits. ``features`` is accepted but currently unused."""
        skip_full, h = self.encoder1(x.view(-1, 1, 10000))
        skip_mid, h = self.encoder2(h)
        h = self.decoder1(self.encoder3(h), skip_mid)
        h = self.decoder2(h, skip_full)
        h = self.c1(self.ap(self.c(h)))
        return h.view(-1, 3, 10)
def load_eeg_label_pairs_resampled_scaled(ids=('A1-1',), conditions=('Vehicle', 'PF')):
    """Load, resample (500 Hz -> 100 Hz) and robust-scale recordings, then stack them.

    Every ``(id, condition)`` pair is loaded with ``load_eeg_label_pair`` and
    processed independently. Each recording gets its own RobustScaler.

    Args:
        ids: iterable of subject ids.
        conditions: conditions to load for every id (default: Vehicle and PF).

    Returns:
        ``(X, y)`` float tensors. ``X`` rows hold 10,000 samples (ten 1,000-sample
        segments) and ``y`` rows are ``(10, 3)``. Both are empty 1-D tensors when
        nothing is loaded.
    """
    old_fs = 500
    fs = 100
    X_parts, y_parts = [], []
    for subject_id in ids:
        for condition in conditions:
            Xi, yi = load_eeg_label_pair(subject_id, condition)
            Xi = Xi.flatten()
            # Keep the recording's duration: (n_samples // old_fs) seconds at fs Hz.
            Xi = torch.from_numpy(resample(Xi, Xi.shape[0] // old_fs * fs)).reshape(-1, fs * 10)
            scaler = RobustScaler()
            Xi = torch.from_numpy(scaler.fit_transform(Xi.reshape(-1, 1)).reshape(-1, fs * 10)).float()
            X_parts.append(Xi.reshape(-1, 10000))
            y_parts.append(yi.reshape(-1, 10, 3))
    if not X_parts:
        return torch.Tensor(), torch.Tensor()
    # Concatenate once (linear), and keep float dtype as the previous
    # cat-onto-empty-float-Tensor accumulation produced.
    return torch.cat(X_parts).float(), torch.cat(y_parts).float()

In [190]:
# Subject-wise split: the first 6 EKYN ids train the model and the last 4 form the
# development set (both default conditions are loaded for each id).
ekyn_ids = get_ekyn_ids()
train_ids = ekyn_ids[:6]
dev_ids = ekyn_ids[-4:]
# With fewer than 10 ids the two slices overlap and leak subjects into dev.
overlap = set(train_ids) & set(dev_ids)
assert not overlap, f"train/dev subject overlap: {overlap}"
BATCH_SIZE = 32
trainloader = DataLoader(Windowset(*load_eeg_label_pairs_resampled_scaled(ids=train_ids)), batch_size=BATCH_SIZE, shuffle=True)
devloader = DataLoader(Windowset(*load_eeg_label_pairs_resampled_scaled(ids=dev_ids)), batch_size=BATCH_SIZE, shuffle=False)

In [191]:
# Seed torch's global RNG so weight initialisation, and the shuffle order the
# training DataLoader draws from that RNG, are reproducible across restarts.
SEED = 0
torch.manual_seed(SEED)
model = UTIME()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
params = sum(p.numel() for p in model.parameters())
print("Params: ", params)

Params:  890383


In [192]:
# Train for N_EPOCHS. After each epoch, compute the mean dev loss and refresh
# the labelled loss curve saved to loss.jpg.
N_EPOCHS = 10
train_lossi = []
dev_lossi = []
for epoch in range(N_EPOCHS):
    model.train()
    train_total = 0
    for X, y in tqdm(trainloader, desc=f'epoch {epoch} train'):
        # (batch, 3, 10) -> (batch*10, 3): one row of class logits per epoch
        logits = model(X).transpose(1, 2).flatten(0, 1)
        # (batch, 10, 3) -> (batch*10, 3), same row order as logits
        loss = criterion(logits, y.flatten(0, 1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_total += loss.item()
    train_lossi.append(train_total / len(trainloader))

    model.eval()
    dev_total = 0
    with torch.no_grad():
        for X, y in tqdm(devloader, desc=f'epoch {epoch} dev'):
            logits = model(X).transpose(1, 2).flatten(0, 1)
            loss = criterion(logits, y.flatten(0, 1))
            dev_total += loss.item()
    dev_lossi.append(dev_total / len(devloader))

    fig, ax = plt.subplots()
    ax.plot(train_lossi, label='train')
    ax.plot(dev_lossi, label='dev')
    ax.set(xlabel='epoch', ylabel='mean cross-entropy loss per batch', title='UTIME training curve')
    ax.legend()
    fig.savefig('loss.jpg')
    plt.close(fig)

100%|██████████| 324/324 [02:31<00:00,  2.14it/s]
100%|██████████| 216/216 [00:30<00:00,  6.97it/s]
100%|██████████| 324/324 [02:27<00:00,  2.19it/s]
100%|██████████| 216/216 [00:30<00:00,  7.18it/s]
100%|██████████| 324/324 [02:28<00:00,  2.18it/s]
100%|██████████| 216/216 [00:29<00:00,  7.30it/s]
100%|██████████| 324/324 [02:24<00:00,  2.24it/s]
100%|██████████| 216/216 [00:30<00:00,  7.07it/s]
100%|██████████| 324/324 [02:28<00:00,  2.18it/s]
100%|██████████| 216/216 [00:31<00:00,  6.91it/s]
100%|██████████| 324/324 [02:28<00:00,  2.19it/s]
100%|██████████| 216/216 [00:30<00:00,  7.10it/s]
100%|██████████| 324/324 [02:29<00:00,  2.17it/s]
100%|██████████| 216/216 [00:30<00:00,  7.00it/s]
100%|██████████| 324/324 [02:28<00:00,  2.18it/s]
100%|██████████| 216/216 [00:30<00:00,  7.08it/s]
100%|██████████| 324/324 [02:27<00:00,  2.19it/s]
100%|██████████| 216/216 [00:30<00:00,  7.09it/s]
100%|██████████| 324/324 [02:27<00:00,  2.19it/s]
100%|██████████| 216/216 [00:30<00:00,  7.10it/s]


In [None]:
# Confusion matrices and mean loss of the trained model on the held-out dev subjects.
y_true, y_pred, y_logits, loss = evaluate_utime(
    devloader, model, criterion
)
cm_grid(y_true, y_pred)
print(loss)

In [None]:
# Per-subject evaluation on the Vehicle recording of every EKYN id. This also
# includes the subjects used for training, so not every score below is held out.
# The last iteration leaves testloader / y_true / y_pred / y_logits bound for later cells.
ids = get_ekyn_ids()
for subject_id in ids:  # avoid shadowing the builtin `id`
    print(subject_id)
    testloader = load_dataloader(id=subject_id, condition='Vehicle', shuffle=False)
    y_true, y_pred, y_logits, loss = evaluate_utime(testloader, model, criterion)
    cm_grid(y_true, y_pred)
    print(loss)

In [None]:
# Stacked per-epoch class probabilities from the most recent evaluate_utime call,
# with the true class overlaid. y_true holds class indices (0-2) while the stack
# sums to 1, so the label trace shares the axis but not the scale.
fig, ax = plt.subplots(figsize=(30, 4), dpi=200)
epochs = range(len(y_logits))  # one x position per row, whatever the recording length
ax.stackplot(epochs, y_logits.T.detach(), labels=[f'p(class {k})' for k in range(y_logits.shape[1])])
ax.plot(y_true, color='k', label='true class')
ax.set(xlabel='epoch index', ylabel='probability / class index', title='Predicted class probabilities vs. true class')
ax.legend(loc='upper right')
fig.savefig('out.pdf')

In [None]:
import plotly.express as px
# Interactive comparison of true vs. predicted class per epoch. Build named,
# numeric columns from the tensors; a DataFrame built from a list of tensors gets
# unnamed 0/1 columns and can end up with object-dtype cells.
labels_df = pd.DataFrame({
    'y_true': y_true.detach().cpu().numpy(),
    'y_pred': y_pred.detach().cpu().numpy(),
})
fig = px.line(labels_df)
fig.show(renderer='browser')

In [None]:
# Predicted class per epoch, taken from the probabilities returned by the most recent
# evaluate_utime call so y_pred stays aligned with y_true. (Previously this read
# `logits` left over from the training cell's last dev batch.) Softmax is monotonic,
# so argmax over probabilities equals argmax over logits.
y_pred = y_logits.argmax(dim=1)

In [None]:
# Inspect the predicted class indices.
y_pred

In [None]:
# Confusion-matrix view of the current y_true vs. y_pred via the project helper
# cm_grid (defined in one of the lib.* star imports — see there for its layout).
cm_grid(y_true,y_pred)

In [None]:
# Hand-rolled evaluation pass over `testloader` (the last recording loaded by the
# per-id loop above). Labels from load_dataloader are (batch, 10, 3), with the class
# on the last axis. UTIME returns (batch, 3, 10) logits, with the class on dim 1.
with torch.no_grad():
    y_true = torch.Tensor()
    y_pred = torch.Tensor()
    y_logits = torch.Tensor()
    loss_total = 0
    for (Xi, yi) in testloader:
        # Class index per epoch: argmax over the class axis -> (batch, 10) -> flat.
        y_true = torch.cat([y_true, yi.argmax(dim=-1).flatten()])

        logits = model(Xi)
        # CrossEntropyLoss expects the class axis at dim 1 for both the (N, C, d)
        # input and a probability target of the same shape.
        loss = criterion(logits, yi.transpose(1, 2).float())
        loss_total += loss.item()

        probs = torch.softmax(logits, dim=1).detach().cpu()
        # (batch, 3, 10) -> (batch*10, 3), same row order as y_true
        y_logits = torch.cat([y_logits, probs.transpose(1, 2).flatten(0, 1)])
        y_pred = torch.cat([y_pred, probs.argmax(dim=1).flatten()])
    mean_loss = loss_total / len(testloader)
mean_loss

In [None]:
# Confusion-matrix view of the y_true / y_pred produced by the manual evaluation
# loop above, via the project helper cm_grid (from the lib.* star imports).
cm_grid(y_true,y_pred)