In [2]:
'''
!pip install ccxt 
!pip install transformers 
!pip install lightgbm 
!pip install pandas-ta 
!pip install seaborn
!pip install --upgrade twisted 
'''

'\n!pip install ccxt \n!pip install transformers \n!pip install lightgbm \n!pip install pandas-ta \n!pip install seaborn\n!pip install --upgrade twisted \n'

In [3]:
import numpy as np 
import pandas as pd 
import json
import ccxt 
from tqdm.auto import tqdm
import pandas_ta as ta
import seaborn as sns
from xgboost import XGBClassifier  
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import f1_score 
import lightgbm as lgbm  

# import libraries for NN 
import random 
import torch 
from torch import Tensor 
import torch.nn as nn 
import torch.optim as optim 
import torch.nn.functional as F 
from torch.utils.data import Dataset, DataLoader, TensorDataset, RandomSampler, SequentialSampler, IterableDataset  
from tqdm.auto import tqdm  
from transformers import AutoModel, AutoTokenizer, AutoConfig, AdamW, get_linear_schedule_with_warmup
import matplotlib.pyplot as plt
import time
import math

In [11]:
# define model
class PositionalEncoding(nn.Module): 
    def __init__(self, d_model, dropout=0.1, max_len=5000): 
        super(PositionalEncoding, self).__init__() 
        self.dropout = nn.Dropout(p=dropout) 
        pe = torch.zeros(max_len, d_model) 
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 
        pe[:, 0::2] = torch.sin(position * div_term) 
        pe[:, 1::2] = torch.cos(position * div_term) 
        pe = pe.unsqueeze(0).transpose(0,1) 
        self.register_buffer("pe", pe) 
    def forward(self, x): 
        x = x + self.pe[:x.size(0), :] 
        return self.dropout(x) 

class MultiSampleDropout(nn.Module): 
    def __init__(self, max_dropout_rate, num_samples, classifier): 
        super(MultiSampleDropout, self).__init__() 
        self.dropout = nn.Dropout 
        self.classifier = classifier 
        self.max_dropout_rate = max_dropout_rate 
        self.num_samples = num_samples 
    def forward(self, out): 
        return torch.mean(torch.stack([self.classifier(self.dropout(p=self.max_dropout_rate)(out)) for _, rate in enumerate(np.linspace(0, self.max_dropout_rate, self.num_samples))], dim=0), dim=0)

class AttentivePooling(torch.nn.Module): 
    def __init__(self, input_dim): 
        super(AttentivePooling, self).__init__() 
        self.W = nn.Linear(input_dim, 1) 
    def forward(self, x): 
        softmax = F.softmax 
        att_w = softmax(self.W(x).squeeze(-1)).unsqueeze(-1) 
        x = torch.sum(x * att_w, dim=1) 
        return x 

class NeuralCLF(nn.Module): 
    def __init__(self, chart_features, sequence_length, d_model, num_classes, n_heads, num_encoders): 
        super(NeuralCLF, self).__init__() 
        self.chart_features = chart_features 
        self.sequence_length = sequence_length  
        self.d_model = d_model 
        self.num_classes = num_classes  
        self.n_heads = n_heads 
        self.num_encoders = num_encoders 
        self.chart_embedder = nn.Sequential(
            nn.Linear(self.chart_features, d_model//2), 
            nn.ReLU(), 
            nn.Linear(d_model//2, d_model) 
        ) 
        self.pos_encoder = PositionalEncoding(d_model=self.d_model) 
        self.encoder_layers = nn.TransformerEncoderLayer(d_model=self.d_model, nhead=self.n_heads, batch_first=True) 
        self.transformer_encoder = nn.TransformerEncoder(self.encoder_layers, num_layers=self.num_encoders) 
        self.attentive_pooling = AttentivePooling(input_dim=self.d_model)   
        self.fc = nn.Linear(self.d_model, self.num_classes) 
        self.multi_dropout = MultiSampleDropout(0.2, 8, self.fc) 
    def forward(self, x): 
        x = self.chart_embedder(x)
        x = self.pos_encoder(x) 
        x = self.transformer_encoder(x) 
        x = self.attentive_pooling(x) 
        x = self.multi_dropout(x)  
        return x 

In [12]:
class WeightedFocalLoss(nn.Module): 
    def __init__(self, alpha, gamma=2):
        super(WeightedFocalLoss, self).__init__() 
        self.alpha = alpha 
        self.device = torch.device("cuda") 
        self.alpha = self.alpha.to(self.device) 
        self.gamma = gamma 
    def forward(self, inputs, targets): 
        CE_loss = nn.CrossEntropyLoss()(inputs, targets) 
        targets = targets.type(torch.long) 
        at = self.alpha.gather(0, targets.data.view(-1)) 
        pt = torch.exp(-CE_loss) 
        F_loss = at * (1-pt)**self.gamma * CE_loss 
        return F_loss.mean() 

In [13]:
def flat_accuracy(preds, labels): 
    pred_flat = np.argmax(preds, axis=1).flatten() 
    labels_flat = labels.flatten() 
    return np.sum(pred_flat==labels_flat) / len(labels_flat) 

### Train High Model

In [14]:
with open('BTC_USDT-15m-4.json') as f:
    d = json.load(f)
    
chart_df = pd.DataFrame(d)
chart_df = chart_df.rename(columns={0:"timestamp",
                                    1:"open",
                                    2:"high",
                                    3:"low",
                                    4:"close",
                                    5:"volume"})

def process(df): 
    binance = ccxt.binance() 
    dates = df['timestamp'].values 
    timestamp = [] 
    for i in range(len(dates)): 
        date_string = binance.iso8601(int(dates[i])) 
        date_string = date_string[:10] + " " + date_string[11:-5] 
        timestamp.append(date_string) 
    df['datetime'] = timestamp 
    df = df.drop(columns={'timestamp'})
    return df

chart_df = process(chart_df)

minutes = [] 
hours = []
days = [] 
months = [] 
years = [] 
for dt in tqdm(chart_df['datetime']): 
    minute = pd.to_datetime(dt).minute 
    hour = pd.to_datetime(dt).hour 
    day = pd.to_datetime(dt).day 
    month = pd.to_datetime(dt).month 
    year = pd.to_datetime(dt).year  
    minutes.append(minute) 
    hours.append(hour) 
    days.append(day) 
    months.append(month)
    years.append(year) 

chart_df["minute"] = minutes 
chart_df['hour'] = hours
chart_df['day'] = days 
chart_df['month'] = months 
chart_df['year'] = years 

  0%|          | 0/181982 [00:00<?, ?it/s]

In [15]:
p_open, p_high, p_low, p_close, p_volume = [], [], [], [], [] 
p_dt, p_minute, p_hour, p_day, p_month, p_year = [], [], [], [], [], [] 

for i in tqdm(range(chart_df.shape[0] - 16), position=0, leave=True): 
    segment = chart_df.iloc[i:i+16] 
    open_val = segment["open"].values[0] 
    high_val = np.max(segment["high"].values) 
    low_val = np.min(segment["low"].values) 
    close_val = segment["close"].values[-1] 
    volume_val = np.sum(segment["volume"].values) 
    
    dt_val = segment["datetime"].values[0] 
    minute_val = segment["minute"].values[0] 
    hour_val = segment["hour"].values[0] 
    day_val = segment["day"].values[0] 
    month_val = segment["month"].values[0] 
    year_val = segment["year"].values[0] 
    
    p_open.append(open_val) 
    p_high.append(high_val) 
    p_low.append(low_val) 
    p_close.append(close_val) 
    p_volume.append(volume_val) 
    p_dt.append(dt_val) 
    p_minute.append(minute_val) 
    p_hour.append(hour_val) 
    p_day.append(day_val) 
    p_month.append(month_val)
    p_year.append(year_val) 

  0%|          | 0/181966 [00:00<?, ?it/s]

In [16]:
four_chart_df = pd.DataFrame(list(zip(p_open, p_high, p_low, p_close, p_volume, p_dt, p_minute, p_hour, p_day, p_month, p_year)), 
                             columns=["open","high","low","close","volume","datetime","minute","hour","day","month","year"])
                             
four_chart_df.head() 

Unnamed: 0,open,high,low,close,volume,datetime,minute,hour,day,month,year
0,4261.48,4349.99,4261.32,4349.99,82.088865,2017-08-17 04:00:00,0,4,17,8,2017
1,4261.48,4377.85,4261.32,4360.71,80.666221,2017-08-17 04:15:00,15,4,17,8,2017
2,4280.0,4377.85,4267.99,4360.7,71.622199,2017-08-17 04:30:00,30,4,17,8,2017
3,4310.07,4377.85,4287.41,4360.69,49.797909,2017-08-17 04:45:00,45,4,17,8,2017
4,4308.83,4377.85,4287.41,4360.69,35.880663,2017-08-17 05:00:00,0,5,17,8,2017


In [17]:
chart_df.shape, four_chart_df.shape

((181982, 11), (181966, 11))

In [18]:
dfs = [] 

for i in tqdm(range(16), position=0, leave=True): 
    dfs.append(four_chart_df.iloc[i::16]) 

  0%|          | 0/16 [00:00<?, ?it/s]

In [20]:
def preprocess_seq_data(chart_df, target="high", threshold=0.0075): 
    targets = [] 
    openv = chart_df["open"].values 
    close = chart_df["close"].values 
    high = chart_df["high"].values 
    low = chart_df["low"].values  
    volume = chart_df["volume"].values 
    
    if target == "high":
        for i in range(close.shape[0]-1):
            high_vol = (high[i+1] - close[i]) / close[i] 
            if high_vol >= threshold: 
                targets.append(1) 
            else: 
                targets.append(0) 
    elif target == "low": 
        for i in range(close.shape[0]-1):
            low_vol = (low[i+1] - close[i]) / close[i] 
            if low_vol <= -threshold: 
                targets.append(1)
            else:
                targets.append(0) 
    targets.append(None) 
    chart_df["Targets"] = targets 
    
    chart_df.set_index(pd.DatetimeIndex(chart_df["datetime"]), inplace=True)
    chart_df["bop"] = chart_df.ta.bop(lookahead=False) 
    chart_df["ebsw"] = chart_df.ta.ebsw(lookahead=False) 
    chart_df["cmf"] = chart_df.ta.cmf(lookahead=False) 
    chart_df["rsi/100"] = chart_df.ta.rsi(lookahead=False) / 100
    chart_df["high/low"] = chart_df["high"] / chart_df["low"] 
    chart_df["high/open"] = chart_df["high"] / chart_df["open"] 
    chart_df["low/open"] = chart_df["low"] / chart_df["open"] 
    chart_df["close/open"] = chart_df["close"] / chart_df["open"] 
    chart_df["high/close"] = chart_df["high"] / chart_df["close"] 
    chart_df["low/close"] = chart_df["low"] / chart_df["close"] 
    
    ratio_open = [None] 
    ratio_close = [None] 
    ratio_high = [None] 
    ratio_low = [None] 
    ratio_volume = [None] 
    for i in range(1, len(openv)): 
        r_open = openv[i] / openv[i-1] 
        r_close = close[i] / close[i-1] 
        r_high = high[i] / high[i-1] 
        r_low = low[i] / low[i-1] 
        if volume[i-1] == 0: 
            r_vol = 1 
        else:
            r_vol = volume[i] / volume[i-1]
        ratio_open.append(r_open) 
        ratio_close.append(r_close) 
        ratio_high.append(r_high) 
        ratio_low.append(r_low) 
        ratio_volume.append(r_vol) 
    
    chart_df["r_open"] = ratio_open 
    chart_df["r_close"] = ratio_close 
    chart_df["r_high"] = ratio_high 
    chart_df["r_low"] = ratio_low 
    chart_df["r_volume"] = ratio_volume 
    chart_df.dropna(inplace=True) 
    return chart_df

In [22]:
processed_charts = [] 

for df in tqdm(dfs):
    processed_df = preprocess_seq_data(df) 
    processed_charts.append(processed_df) 

  0%|          | 0/16 [00:00<?, ?it/s]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  chart_df["Targets"] = targets
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  chart_df["bop"] = chart_df.ta.bop(lookahead=False)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  chart_df["ebsw"] = chart_df.ta.ebsw(lookahead=False)
A value is trying to be set on a copy of a slice from a DataFrame.
Try 

In [24]:
X = []
Y = [] 

for i in range(len(processed_charts)): 
    X.append(processed_charts[i][["bop", "ebsw", "cmf", "rsi/100", "r_open", "r_close", "r_high", "r_low", "r_volume", "high/low", "high/open", "low/open", "close/open", 
                                  "high/close", "low/close"]])
    Y.append(processed_charts[i][["Targets"]]) 

In [28]:
seq_len = 42

X_train, X_val, X_test = [], [], [] 
Y_train, Y_val, Y_test = [], [], [] 

for i in tqdm(range(len(processed_charts))): 
    cur_X = X[i] 
    cur_Y = Y[i] 
    X_seq, Y_labels = [], [] 
    for j in range(cur_X.shape[0]-seq_len): 
        X_seq.append(cur_X.iloc[j:j+seq_len].values) 
        Y_labels.append(cur_Y.iloc[j+seq_len-1].values[0]) 
    
    train_size = int(0.8 * len(X_seq)) 
    val_size = int(0.1 * len(X_seq)) 
    X_train.extend(X_seq[:train_size]) 
    Y_train.extend(Y_labels[:train_size]) 
    
    X_val.extend(X_seq[train_size:train_size+val_size]) 
    Y_val.extend(Y_labels[train_size:train_size+val_size]) 
    
    X_test.extend(X_seq[train_size+val_size:]) 
    Y_test.extend(Y_labels[train_size+val_size:]) 

  0%|          | 0/16 [00:00<?, ?it/s]

In [29]:
class_weights = compute_class_weight(class_weight="balanced", classes=np.unique(Y_train), y=np.array(Y_train)) 
class_weights

array([0.98446782, 1.01603015])

In [30]:
X_train = torch.tensor(X_train).float() 
Y_train = torch.tensor(Y_train, dtype=int) 

X_val = torch.tensor(X_val).float() 
Y_val = torch.tensor(Y_val, dtype=int) 

X_test = torch.tensor(X_test).float() 
Y_test = torch.tensor(Y_test, dtype=int) 


X_train.shape, Y_train.shape, X_val.shape, Y_val.shape, X_test.shape, Y_test.shape

  X_train = torch.tensor(X_train).float()
  Y_train = torch.tensor(Y_train, dtype=int)
  Y_val = torch.tensor(Y_val, dtype=int)
  Y_test = torch.tensor(Y_test, dtype=int)


(torch.Size([144512, 42, 15]),
 torch.Size([144512]),
 torch.Size([18064, 42, 15]),
 torch.Size([18064]),
 torch.Size([18078, 42, 15]),
 torch.Size([18078]))

In [31]:
batch_size = 256 

train_data = TensorDataset(X_train, Y_train) 
train_sampler = RandomSampler(train_data) 
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size) 

val_data = TensorDataset(X_val, Y_val) 
val_sampler = SequentialSampler(val_data) 
val_dataloader = DataLoader(val_data, sampler=val_sampler, batch_size=batch_size) 

test_data = TensorDataset(X_test, Y_test) 
test_sampler = SequentialSampler(test_data) 
test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size) 

In [32]:
train_losses, val_losses = [], [] 
train_accuracies, val_accuracies = [], [] 

device = torch.device("cuda") 
model = NeuralCLF(chart_features=X_train.shape[2], sequence_length=X_train.shape[1], d_model=256, num_classes=2, n_heads=8, num_encoders=6) 
model.cuda() 
optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8) 
epochs = 20  
total_steps = len(train_dataloader) * epochs 
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=int(0.1*total_steps), num_training_steps=total_steps) 
class_weights = torch.tensor(class_weights, dtype=torch.float) 
loss_func = WeightedFocalLoss(alpha=class_weights) 




In [34]:
# training logic 
model.zero_grad() 
for epoch_i in tqdm(range(epochs), desc="Epochs", position=0, leave=True, total=epochs): 
    train_loss, train_accuracy = 0, 0 
    model.train() 
    with tqdm(train_dataloader, unit="batch") as tepoch: 
        for step, batch in enumerate(tepoch): 
            batch = tuple(t.to(device) for t in batch) 
            b_seqs, b_labels = batch 
            outputs = model(b_seqs) 
            loss = loss_func(outputs, b_labels) 
            train_loss += loss.item() 
            logits_cpu, labels_cpu = outputs.detach().cpu().numpy(), b_labels.detach().cpu().numpy() 
            train_accuracy += flat_accuracy(logits_cpu, labels_cpu) 
            loss.backward() 
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) 
            optimizer.step() 
            scheduler.step() 
            model.zero_grad() 
            tepoch.set_postfix(loss=train_loss / (step+1), accuracy=100.0 * train_accuracy / (step+1)) 
            time.sleep(0.1) 
        avg_train_loss = train_loss / len(train_dataloader) 
        avg_train_accuracy = train_accuracy / len(train_dataloader)  
        train_losses.append(avg_train_loss) 
        train_accuracies.append(avg_train_accuracy) 
        print(f"average train loss : {avg_train_loss}") 
        print(f"average train accuracy : {avg_train_accuracy}") 
    val_loss, val_accuracy = 0, 0 
    model.eval() 
    for step, batch in tqdm(enumerate(val_dataloader), position=0, leave=True, total=len(val_dataloader)): 
        batch = tuple(t.to(device) for t in batch) 
        b_seqs, b_labels = batch 
        with torch.no_grad(): 
            outputs = model(b_seqs) 
        loss = loss_func(outputs, b_labels) 
        val_loss += loss.item()
        logits_cpu, labels_cpu = outputs.detach().cpu().numpy(), b_labels.detach().cpu().numpy() 
        val_accuracy += flat_accuracy(logits_cpu, labels_cpu) 
    avg_val_loss = val_loss / len(val_dataloader) 
    avg_val_accuracy = val_accuracy / len(val_dataloader) 
    val_losses.append(avg_val_loss) 
    val_accuracies.append(avg_val_accuracy) 
    print(f"average val loss: {avg_val_loss}") 
    print(f"average val accuracy: {avg_val_accuracy}") 
    print("saving current checkpoint...") 
    torch.save(model.state_dict(), f"TFNet_CLF_val_acc:{avg_val_accuracy}_val_loss:{avg_val_loss}.pt") 

Epochs:   0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/565 [00:00<?, ?batch/s]

  att_w = softmax(self.W(x).squeeze(-1)).unsqueeze(-1)


average train loss : 0.18100434644559843
average train accuracy : 0.5023506637168141


  0%|          | 0/71 [00:00<?, ?it/s]

average val loss: 0.17419067822711568
average val accuracy: 0.5098787167449139
saving best model...


  0%|          | 0/565 [00:00<?, ?batch/s]

average train loss : 0.17514571639816318
average train accuracy : 0.5034499446902655


  0%|          | 0/71 [00:00<?, ?it/s]

average val loss: 0.1717812633010703
average val accuracy: 0.5326682316118936
saving best model...


  0%|          | 0/565 [00:00<?, ?batch/s]

average train loss : 0.17408974819478734
average train accuracy : 0.5046805862831858


  0%|          | 0/71 [00:00<?, ?it/s]

average val loss: 0.17190403522739947
average val accuracy: 0.5331939553990611


  0%|          | 0/565 [00:00<?, ?batch/s]

average train loss : 0.17385032271916886
average train accuracy : 0.5047566371681416


  0%|          | 0/71 [00:00<?, ?it/s]

average val loss: 0.17235689452836211
average val accuracy: 0.537014622456964


  0%|          | 0/565 [00:00<?, ?batch/s]

average train loss : 0.17346002211612938
average train accuracy : 0.5081996681415929


  0%|          | 0/71 [00:00<?, ?it/s]

average val loss: 0.17422684355520865
average val accuracy: 0.4870892018779343


  0%|          | 0/565 [00:00<?, ?batch/s]

average train loss : 0.17249244493720806
average train accuracy : 0.5198700221238938


  0%|          | 0/71 [00:00<?, ?it/s]

average val loss: 0.17273475733441365
average val accuracy: 0.522801741001565


  0%|          | 0/565 [00:00<?, ?batch/s]

average train loss : 0.16056530836936647
average train accuracy : 0.5829300331858407


  0%|          | 0/71 [00:00<?, ?it/s]

average val loss: 0.17488832687827902
average val accuracy: 0.5485377543035994


  0%|          | 0/565 [00:00<?, ?batch/s]

average train loss : 0.15199977829393033
average train accuracy : 0.6089670907079646


  0%|          | 0/71 [00:00<?, ?it/s]

average val loss: 0.19412992831686854
average val accuracy: 0.547633020344288


  0%|          | 0/565 [00:00<?, ?batch/s]

average train loss : 0.14976732566029624
average train accuracy : 0.6162887168141593


  0%|          | 0/71 [00:00<?, ?it/s]

average val loss: 0.174427394715833
average val accuracy: 0.5558184174491393


  0%|          | 0/565 [00:00<?, ?batch/s]

average train loss : 0.14810962717881246
average train accuracy : 0.6192408738938053


  0%|          | 0/71 [00:00<?, ?it/s]

average val loss: 0.17321756629037185
average val accuracy: 0.5560201486697965


  0%|          | 0/565 [00:00<?, ?batch/s]

average train loss : 0.1472362669695795
average train accuracy : 0.6225525442477876


  0%|          | 0/71 [00:00<?, ?it/s]

average val loss: 0.18182640680125062
average val accuracy: 0.5551215277777778


  0%|          | 0/565 [00:00<?, ?batch/s]

average train loss : 0.14637546752933908
average train accuracy : 0.6238454092920354


  0%|          | 0/71 [00:00<?, ?it/s]

average val loss: 0.17751193634221252
average val accuracy: 0.5582942097026604


  0%|          | 0/565 [00:00<?, ?batch/s]

average train loss : 0.14549993703850603
average train accuracy : 0.6262928650442477


  0%|          | 0/71 [00:00<?, ?it/s]

average val loss: 0.17531924936133372
average val accuracy: 0.5612896126760564


  0%|          | 0/565 [00:00<?, ?batch/s]

KeyboardInterrupt: 

In [None]:
# test model performance 
# check performance on test set 
model = NeuralCLF(chart_features=X_train.shape[2], sequence_length=X_train.shape[1], d_model=256, num_classes=2, n_heads=8, num_encoders=6) 
checkpoint = torch.load("") 
best_model.load_state_dict(checkpoint) 
best_model.cuda() 
best_model.eval() 

In [35]:
pred_classes = [] 
model.eval() 
for step, batch in tqdm(enumerate(test_dataloader), desc="Testing", position=0, leave=True, total=len(test_dataloader)): 
    batch = tuple(t.to(device) for t in batch) 
    b_seqs, b_labels = batch 
    with torch.no_grad(): 
        output = model(b_seqs) 
        
    pred_class = torch.argmax(output, dim=1)  
    pred_classes.extend(pred_class) 

Testing:   0%|          | 0/71 [00:00<?, ?it/s]

  att_w = softmax(self.W(x).squeeze(-1)).unsqueeze(-1)


In [36]:
pred_classes_cpu = [] 
for p in pred_classes:
    pred_classes_cpu.append(p.detach().cpu()) 

cnt = 0 
for i in range(len(Y_test)): 
    if Y_test[i] == pred_classes_cpu[i]:  
        cnt += 1 
        
only_zero = 0 
only_one = 0 
for i in range(len(Y_test)): 
    if Y_test[i] == 0: 
        only_zero += 1 
    if Y_test[i] == 1: 
        only_one += 1 

print(f"model accuracy = {cnt/len(Y_test)*100}")  
print(f"model F1 = {f1_score(Y_test, pred_classes_cpu)}") 
print(f"only zero agent = {only_zero/len(Y_test)*100}") 
print(f"only one agent = {only_one/len(Y_test)*100}") 

model accuracy = 58.856068149131545
model F1 = 0.4229635376260667
only zero agent = 56.04602279013166
only one agent = 43.95397720986835


In [37]:
torch.save(model.state_dict(), f"TFNet_CLF_test_acc_58.856.pt") 