In [18]:
import transformers
from transformers import BertTokenizer, BertForSequenceClassification

import torch
import torch.nn.functional as F

In [19]:
finbert = 'ProsusAI/finbert'

In [20]:
model = BertForSequenceClassification.from_pretrained(finbert)

In [21]:
model

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [22]:
tokenizer = BertTokenizer.from_pretrained(finbert)

Downloading (…)solve/main/vocab.txt: 0.00B [00:00, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/252 [00:00<?, ?B/s]

In [23]:
sample_data = ("Given the recent downturn in stocks especially in tech which is likely to persist as yields keep going up, "
       "I thought it would be prudent to share the risks of investing in ARK ETFs, written up very nicely by "
       "[The Bear Cave](https://thebearcave.substack.com/p/special-edition-will-ark-invest-blow). The risks comes "
       "primarily from ARK's illiquid and very large holdings in small cap companies. ARK is forced to sell its "
       "holdings whenever its liquid ETF gets hit with outflows as is especially the case in market downturns. "
       "This could force very painful liquidations at unfavorable prices and the ensuing crash goes into a "
       "positive feedback loop leading into a death spiral enticing even more outflows and predatory shorts.")
sample_data

"Given the recent downturn in stocks especially in tech which is likely to persist as yields keep going up, I thought it would be prudent to share the risks of investing in ARK ETFs, written up very nicely by [The Bear Cave](https://thebearcave.substack.com/p/special-edition-will-ark-invest-blow). The risks comes primarily from ARK's illiquid and very large holdings in small cap companies. ARK is forced to sell its holdings whenever its liquid ETF gets hit with outflows as is especially the case in market downturns. This could force very painful liquidations at unfavorable prices and the ensuing crash goes into a positive feedback loop leading into a death spiral enticing even more outflows and predatory shorts."

In [24]:

tokens = tokenizer.encode_plus(sample_data, max_length=512, truncation=True, padding='max_length', add_special_tokens=True, return_tensors='pt')

tokens

{'input_ids': tensor([[  101,  2445,  1996,  3522,  2091, 22299,  1999, 15768,  2926,  1999,
          6627,  2029,  2003,  3497,  2000, 29486,  2004, 16189,  2562,  2183,
          2039,  1010,  1045,  2245,  2009,  2052,  2022, 10975, 12672,  3372,
          2000,  3745,  1996, 10831,  1997, 19920,  1999, 15745,  3802, 10343,
          1010,  2517,  2039,  2200, 19957,  2011,  1031,  1996,  4562,  5430,
          1033,  1006, 16770,  1024,  1013,  1013,  1996,  4783,  2906, 27454,
          1012,  4942,  9153,  3600,  1012,  4012,  1013,  1052,  1013,  2569,
          1011,  3179,  1011,  2097,  1011, 15745,  1011, 15697,  1011,  6271,
          1007,  1012,  1996, 10831,  3310,  3952,  2013, 15745,  1005,  1055,
          5665, 18515, 21272,  1998,  2200,  2312,  9583,  1999,  2235,  6178,
          3316,  1012, 15745,  2003,  3140,  2000,  5271,  2049,  9583,  7188,
          2049,  6381,  3802,  2546,  4152,  2718,  2007,  2041, 12314,  2015,
          2004,  2003,  2926,  1996,  

In [26]:
pred = model(**tokens)
pred

SequenceClassifierOutput(loss=None, logits=tensor([[-1.8200,  2.4484,  0.0216]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

In [27]:
logits = pred.logits
logits

tensor([[-1.8200,  2.4484,  0.0216]], grad_fn=<AddmmBackward0>)

In [28]:
probs = F.softmax(logits, dim = -1)
probs

tensor([[0.0127, 0.9072, 0.0801]], grad_fn=<SoftmaxBackward0>)

In [29]:
output = torch.argmax(probs)
output

tensor(1)