In [2]:
#load LFP data
import numpy as np
import gc
stim_num = [110000, 110101, 110105, 110106, 110107, 110109, 110110, 110111, 110506, 110511, 111105, 111109, 111201, 111299, 111301, 111302, 111303, 111304, 111305, 111306, 111307, 111308]
sub_sessions = [("sub-619296", "ses-1187930705"), ("sub-620333", "ses-1188137866"), ("sub-620334", "ses-1189887297"), ("sub-625545", "ses-1182865981"), ("sub-625554", "ses-1181330601"), ("sub-625555", "ses-1183070926"), ("sub-630506", "ses-1192952695"), ("sub-631510", "ses-1196157974"), ("sub-631570", "ses-1194857009"), ("sub-633229", "ses-1199247593"), ("sub-637484", "ses-1208667752")]
session_num = 3
stim_windows = [] # (trial, time, channel)
probe_num = 2
stim_windows2 = []
probe_num2 = 3
for frame_num in range(len(stim_num)) : 
    stim_windows.append(np.load(f'../material/LFP_npy_data/{sub_sessions[session_num][0]}_{sub_sessions[session_num][1]}/probe{probe_num}_frame{frame_num}.npy'))
    stim_windows2.append(np.load(f'../material/LFP_npy_data/{sub_sessions[session_num][0]}_{sub_sessions[session_num][1]}/probe{probe_num2}_frame{frame_num}.npy'))

In [3]:
#data
import random
frame_stimtype = [(3, 'IC1'), (7, 'IC2'), (8, 'IRE1'), (9, 'IRE2')]

data = [] # (trial #, channel, time)
labels = []

for i, (frame_num, stimtype) in enumerate(frame_stimtype) :
    trial_nums = np.arange(stim_windows[frame_num].shape[0])
    random.shuffle(trial_nums)
    for j, trial_num in enumerate(trial_nums) :
        if j == 50 : break
        temp_data = np.concatenate((stim_windows[frame_num][trial_num,:,:], stim_windows2[frame_num][trial_num,:,:]), axis = 1)
        data.append(temp_data)
        labels.append(i)
data = np.array(data)
labels = np.array(labels)
del stim_windows
del stim_windows2
gc.collect()

231

tensor([4, 7, 4, 5, 1, 4, 2, 9, 5, 0, 5, 6, 3, 3, 7, 8, 5, 6, 1, 9, 6, 3, 9, 5,
        3, 9, 3, 7, 0, 9, 0, 4, 7, 4, 0, 7, 5, 0, 6, 1, 2, 9, 5, 4, 6, 4, 0, 0,
        2, 5, 8, 9, 4, 0, 0, 0, 7, 3, 3, 2, 9, 0, 8, 9])

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import KFold
from sklearn.metrics import confusion_matrix
import numpy as np
import seaborn as sn 
import pandas as pd
import matplotlib.pyplot as plt

if __name__ == '__main__' :
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(device)

class CSDDataset(Dataset):
    def __init__(self, data, labels):
        self.data = torch.tensor(data, dtype=torch.float32).unsqueeze(1).to(device)  # (200, 400, 84) -> (200, 1, 400, 84)
        self.labels = torch.tensor(labels, dtype=torch.long).to(device)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

class Generator(nn.Module):
    def __init__(self, channel_num, time_len, class_num, z_size = 100):
        super().__init__()

        self.z_size = z_size
        self.channel_num = channel_num
        self.time_len = time_len
        self.class_num = class_num

        self.label_emb = nn.Embedding(self.class_num, self.class_num)

        self.linear1 = nn.Linear(z_size+class_num, 256)
        self.linear2 = nn.Linear(256, 512)
        self.linear3 = nn.Linear(512, 1024)
        self.linear4 = nn.Linear(1024, channel_num * time_len)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, z, labels):
        x = torch.cat((self.label_emb(labels), z), -1)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.relu(self.linear3(x))
        x = self.linear4(x)
        x = self.tanh(x) #sigmoid보다 더 또렷한 이미지를 만든다고 함
        x = x.view(x.size(0), 1, self.channel_num, self.time_len)
        return x

class Discriminator(nn.Module) :
    def __init__(self, channel_num, time_len, class_num):
        super().__init__()

        self.channel_num = channel_num
        self.time_len = time_len
        self.class_num = class_num

        self.label_emb = nn.Embedding(self.class_num, self.class_num)
        
        self.linear1 = nn.Linear(self.channel_num * self.time_len + self.class_num, 1024)
        self.linear2 = nn.Linear(1024, 512)
        self.linear3 = nn.Linear(512, 256)
        self.linear4 = nn.Linear(256, 1)
        self.leaky_relu = nn.LeakyReLU(0.2) #보통 generator에는 relu, discrimiator에는 leaky relu(gradient의 소실 방지)를 사용하는데 이유는 잘 모르겠음
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, labels):
        x = torch.cat((x.view(x.size(0),-1),self.label_emb(labels)),-1)
        x = self.leaky_relu(self.linear1(x))
        x = self.leaky_relu(self.linear2(x))
        x = self.leaky_relu(self.linear3(x))
        x = self.linear4(x)
        x = self.sigmoid(x)
        return x

dataset = CSDDataset(data, labels)

train_loader = DataLoader(dataset, batch_size = 10, shuffle = True)

learning_rate = 0.001
num_epochs = 100

generator = Generator(data.shape[2], data.shape[1], len(frame_stimtype)).to(device)
discrimiator = Discriminator(data.shape[2], data.shape[1], len(frame_stimtype)).to(device)

criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=learning_rate)
g_optimizer = torch.optim.Adam(generator.parameters(), lr=learning_rate)

loss_history={'gen':[],
              'dis':[]}

# Train
batch_count = 0
start_time = time.time()
discrimiator.train()
generator.train()

for epoch in range(num_epochs):
    for xb, yb in train_loader:
        batch_size = xb.shape[0]
        
        yb_real = torch.Tensor(batch_size, 1).fill_(1.0).to(device) # real_label
        yb_fake = torch.Tensor(batch_size, 1).fill_(0.0).to(device) # fake_label
        
        # Genetator
        generator.zero_grad()
        z = torch.randn(batch_size,100).to(device) # 노이즈 생성
        gen_label = torch.randint(0,4,(batch_size,)).to(device) # label 생성

         # 가짜 이미지 생성
        out_gen = generator(z, gen_label)

        # 가짜 이미지 판별
        out_dis = discrimiator(out_gen, gen_label)

        loss_gen = criterion(out_dis, yb_real)
        loss_gen.backward()
        g_optimizer.step()

        # Discriminator
        discrimiator.zero_grad()
        
        # 진짜 이미지 판별
        out_dis = discrimiator(xb, yb)
        loss_real = criterion(out_dis, yb_real)

        # 가짜 이미지 판별
        out_dis = discrimiator(out_gen.detach(),gen_label)
        loss_fake = criterion(out_dis,yb_fake)

        loss_dis = (loss_real + loss_fake) / 2
        loss_dis.backward()
        d_optimizer.step()

        loss_history['gen'].append(loss_gen.item())
        loss_history['dis'].append(loss_dis.item())

        batch_count += 1
        if batch_count % 1000 == 0:
            print('Epoch: %.0f, G_Loss: %.6f, D_Loss: %.6f, time: %.2f min' %(epoch, loss_gen.item(), loss_dis.item(), (time.time()-start_time)/60))

SyntaxError: incomplete input (2920352767.py, line 38)

In [31]:
# plot loss history
plt.figure(figsize=(10,5))
plt.title('Loss Progress')
plt.plot(loss_history['gen'], label='Gen. Loss')
plt.plot(loss_history['dis'], label='Dis. Loss')
plt.xlabel('batch count')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
#가중치 저장
path2models = './models/'
os.makedirs(path2models, exist_ok=True)
path2weights_gen = os.path.join(path2models, 'weights_gen.pt')
path2weights_dis = os.path.join(path2models, 'weights_dis.pt')

torch.save(generator.state_dict(), path2weights_gen)
torch.save(discriminator.state_dict(), path2weights_dis)

In [None]:
# 가중치 불러오기
weights = torch.load(path2weights_gen)
generator.load_state_dict(weights)

# evalutaion mode
generator.eval()

# fake image 생성
with torch.no_grad():
    fig = plt.figure(figsize=(8,8))
    cols, rows = 4, 4 # row와 col 갯수
    for i in range(rows * cols):
        fixed_noise = torch.randn(16, 100, device=device)
        label = torch.randint(0,4,(16,), device=device)
        img_fake = generator(fixed_noise, label).detach().cpu()
        fig.add_subplot(rows, cols, i+1)
        plt.title(label[i].item())
        plt.axis('off')
        plt.imshow(img_fake[i].squeeze(), cmap='gray')
plt.show()