# Crypto Price Prediction with Social & News Sentiment (TFT)
Ready‑to‑run research notebook.

* **Social sentiment:** LunarCrush hourly `average_sentiment` (Twitter + Reddit + other)
* **News sentiment:** FinBERT + CryptoBERT ensemble on NewsAPI headlines
* **Model:** Temporal Fusion Transformer trained under 4 scenarios  
  1. Price only  
  2. Price + Social  
  3. Price + News  
  4. Price + Both  
* Saves all figures/CSV into `results/` for direct paper use


In [None]:
!pip install -q protobuf==3.20.3 \
              ccxt requests pandas numpy scikit-learn matplotlib \
              pytorch-lightning pytorch-forecasting shap tqdm torch transformers



[notice] A new release of pip is available: 25.0.1 -> 25.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [16]:

import os, math, json, time, requests, datetime as dt
from pathlib import Path
import pandas as pd, numpy as np, matplotlib.pyplot as plt
import ccxt, torch, pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.metrics import RMSE
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_absolute_error, mean_squared_error
from transformers import BertTokenizer, BertForSequenceClassification, pipeline
from functools import wraps

pl.seed_everything(42)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DEVICE


Seed set to 42


device(type='cpu')

In [None]:

# --- parameters ---
START_DATE   = '2022-07-01'
SYMBOL       = 'BTC/USDT'
TIMEFRAME    = '1h'

# !! Insert your keys here
LUNAR_KEY    = 'YOUR_LUNAR_KEY'
NEWS_API_KEY = 'YOUR_NEWS_API_KEY'

OUTPUT_DIR = Path('results')
OUTPUT_DIR.mkdir(exist_ok=True)


In [18]:

def retry(excs, tries=3, delay=2):
    def deco(fn):
        @wraps(fn)
        def wrap(*a, **kw):
            m = tries
            while m > 1:
                try:
                    return fn(*a, **kw)
                except excs as e:
                    print(f'Retrying after error: {e}')
                    time.sleep(delay)
                    m -= 1
            return fn(*a, **kw)
        return wrap
    return deco


In [19]:

@retry(Exception, 3)
def fetch_ohlcv(symbol, timeframe, since):
    ex = ccxt.binance()
    out = []
    while True:
        o = ex.fetch_ohlcv(symbol, timeframe, since=since, limit=1000)
        if not o:
            break
        out.extend(o)
        since = o[-1][0] + 1
        if len(o) < 1000:
            break
    df = pd.DataFrame(out, columns=['ts','open','high','low','close','volume'])
    df['datetime'] = pd.to_datetime(df.ts, unit='ms', utc=True)
    return df.set_index('datetime').drop(columns='ts')

@retry((requests.exceptions.RequestException, ValueError),3)
def fetch_lunar_sentiment(symbol, start, end, interval='hour'):
    out = []
    cur = start
    while cur < end:
        nxt = min(cur + dt.timedelta(days=30), end)
        url = (f'https://api.lunarcrush.com/v4?data=assets&symbol={symbol}&interval={interval}'
               f'&start={int(cur.timestamp())}&end={int(nxt.timestamp())}&key={LUNAR_KEY}')
        js = requests.get(url, timeout=10).json()
        if 'data' in js and js['data']:
            out.extend(js['data'][0]['timeSeries'])
        cur = nxt
    df = pd.DataFrame(out)
    df['datetime'] = pd.to_datetime(df['time'], unit='s', utc=True)
    return df.set_index('datetime')[['average_sentiment']].rename(columns={'average_sentiment':'soc_sent'})

@retry((requests.exceptions.RequestException, ValueError),3)
def fetch_news(api_key, q, start, end):
    arts, page = [], 1
    while True:
        url=(f'https://newsapi.org/v2/everything?q={q}&from={start}&to={end}&language=en'
             f'&sortBy=publishedAt&pageSize=100&page={page}&apiKey={api_key}')
        js=requests.get(url,timeout=10).json()
        if js.get('status')!='ok' or not js.get('articles'):
            break
        arts += js['articles']
        page += 1
        if page > 5:
            break
    df = pd.DataFrame([{'datetime': pd.to_datetime(a['publishedAt'], utc=True),
                        'text': a['title'] or ''} for a in arts]).set_index('datetime')
    return df


In [23]:
from transformers import AutoTokenizer

tok_f = BertTokenizer.from_pretrained('yiyanghkust/finbert-tone')
mdl_f = BertForSequenceClassification.from_pretrained('yiyanghkust/finbert-tone').to(DEVICE)
tok_c = AutoTokenizer.from_pretrained(
            'ElKulako/cryptobert', use_fast=True, trust_remote_code=False)
mdl_c = BertForSequenceClassification.from_pretrained(
            'ElKulako/cryptobert').to(DEVICE)
pipe_f = pipeline('sentiment-analysis', model=mdl_f, tokenizer=tok_f, device=0 if DEVICE.type=='cuda' else -1)
pipe_c = pipeline('sentiment-analysis', model=mdl_c, tokenizer=tok_c, device=0 if DEVICE.type=='cuda' else -1)

def ensemble_score(txt, w1=0.5, w2=0.5):
    o1 = {d['label']: d['score'] for d in pipe_f(txt, return_all_scores=True)[0]}
    o2 = {d['label']: d['score'] for d in pipe_c(txt, return_all_scores=True)[0]}
    s1 = o1.get('Positive',0) - o1.get('Negative',0)
    s2 = o2.get('Bullish',0)  - o2.get('Bearish',0)
    return float(np.clip(w1*s1 + w2*s2, -1, 1))


You are using a model of type roberta to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at ElKulako/cryptobert and are newly initialized: ['classifier.bias', 'classifier.weight', 'embeddings.LayerNorm.bias', 'embeddings.LayerNorm.weight', 'embeddings.position_embeddings.weight', 'embeddings.token_type_embeddings.weight', 'embeddings.word_embeddings.weight', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.self.value.weight', 'encoder.la

In [24]:

# Data acquisition & caching
mkt_csv  = OUTPUT_DIR/'market.csv'
soc_csv  = OUTPUT_DIR/'social.csv'
news_csv = OUTPUT_DIR/'news.csv'

if mkt_csv.exists():
    mkt = pd.read_csv(mkt_csv, parse_dates=['datetime'], index_col='datetime')
else:
    since = ccxt.binance().parse8601(f'{START_DATE}T00:00:00Z')
    mkt = fetch_ohlcv(SYMBOL, TIMEFRAME, since)
    mkt['MA'] = mkt['close'].rolling(14).mean()
    d = mkt['close'].diff()
    mkt['RSI'] = 100 - 100 / (1 + (d.clip(lower=0).rolling(14).mean() / (-d.clip(upper=0)).rolling(14).mean()))
    mkt.fillna(method='ffill', inplace=True)
    mkt.to_csv(mkt_csv)

if soc_csv.exists():
    soc = pd.read_csv(soc_csv, parse_dates=['datetime'], index_col='datetime')
else:
    soc = fetch_lunar_sentiment('BTC', dt.datetime.fromisoformat(START_DATE), dt.datetime.utcnow())
    soc.to_csv(soc_csv)

if news_csv.exists():
    nws = pd.read_csv(news_csv, parse_dates=['datetime'], index_col='datetime')
else:
    NEWS_QUERY = ('bitcoin OR btc OR ethereum OR eth OR crypto OR cryptocurrency OR altcoin OR defi')
    raw_news = fetch_news(NEWS_API_KEY, NEWS_QUERY, START_DATE, dt.datetime.utcnow().strftime('%Y-%m-%d'))
    raw_news['clean'] = raw_news['text'].str.lower().str.replace(r'http\S+',' ', regex=True)
    raw_news['sent'] = raw_news['clean'].apply(ensemble_score)
    nws = raw_news[['sent']].rename(columns={'sent':'news_sent'})
    nws.to_csv(news_csv)

# Resample to hourly & merge
mkt  = mkt.resample('1H').ffill()
soc  = soc.resample('1H').mean().fillna(method='ffill')
nws  = nws.resample('1H').mean().fillna(0)
data = mkt.join([soc,nws], how='inner')
data.head()


  mkt.fillna(method='ffill', inplace=True)


Retrying after error: HTTPSConnectionPool(host='api.lunarcrush.com', port=443): Max retries exceeded with url: /v4?data=assets&symbol=BTC&interval=hour&start=1656601200&end=1659193200&key=b1t15t5mabg0h56v3j20np2lbfma0sf0eb4szyxnq (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x000001C75D8E3C50>: Failed to resolve 'api.lunarcrush.com' ([Errno 11001] getaddrinfo failed)"))
Retrying after error: HTTPSConnectionPool(host='api.lunarcrush.com', port=443): Max retries exceeded with url: /v4?data=assets&symbol=BTC&interval=hour&start=1656601200&end=1659193200&key=b1t15t5mabg0h56v3j20np2lbfma0sf0eb4szyxnq (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x000001C75D8E2C90>: Failed to resolve 'api.lunarcrush.com' ([Errno 11001] getaddrinfo failed)"))


ConnectionError: HTTPSConnectionPool(host='api.lunarcrush.com', port=443): Max retries exceeded with url: /v4?data=assets&symbol=BTC&interval=hour&start=1656601200&end=1659193200&key=b1t15t5mabg0h56v3j20np2lbfma0sf0eb4szyxnq (Caused by NameResolutionError("<urllib3.connection.HTTPSConnection object at 0x000001C762B3D110>: Failed to resolve 'api.lunarcrush.com' ([Errno 11001] getaddrinfo failed)"))

In [None]:

plt.figure(figsize=(10,4))
plt.plot(data.index, data['close'], label='Close')
ax2 = plt.twinx()
ax2.plot(data.index, data['soc_sent'], 'r--', label='Social')
ax2.plot(data.index, data['news_sent'], 'g:', label='News')
plt.title('BTC Price vs Sentiments')
plt.legend(loc='upper left')
plt.tight_layout()
plt.savefig(OUTPUT_DIR/'price_sentiments.png', dpi=300)
plt.show()


In [None]:

SPLIT1 = dt.datetime(2023,1,1, tzinfo=dt.timezone.utc)
SPLIT2 = dt.datetime(2023,7,1, tzinfo=dt.timezone.utc)
THRESHOLD = 5.0

def run_scenario(sid, base):
    df = base.copy()
    if sid in (2,4): df['soc'] = df['soc_sent']
    if sid in (3,4): df['news'] = df['news_sent']
    df['target'] = df['close'].shift(-1)
    df.dropna(inplace=True)
    feats = [c for c in df.columns if c not in ['open','high','low','close','volume','target']]
    df[feats] = StandardScaler().fit_transform(df[feats])
    df = df.reset_index().rename(columns={'datetime':'time_idx'})
    df['time_idx'] = np.arange(len(df))
    df['group'] = 'crypto'
    ds = TimeSeriesDataSet(df, time_idx='time_idx', target='target', group_ids=['group'],
                           max_encoder_length=24, max_prediction_length=1,
                           time_varying_known_reals=feats,
                           time_varying_unknown_reals=['target'])
    tr, val = ds.split_before(SPLIT1)
    val, _  = val.split_before(SPLIT2)
    tr_dl  = tr.to_dataloader(train=True, batch_size=64)
    val_dl = val.to_dataloader(train=False, batch_size=64)
    tft = TemporalFusionTransformer.from_dataset(ds, learning_rate=1e-2, hidden_size=16,
                                                 attention_head_size=1, dropout=0.1,
                                                 hidden_continuous_size=8, loss=RMSE()).to(DEVICE)
    pl.Trainer(max_epochs=5, accelerator='gpu' if DEVICE.type=='cuda' else 'cpu', devices=1,
               callbacks=[EarlyStopping('val_loss', patience=2)], logger=False,
               enable_checkpointing=False).fit(tft, tr_dl, val_dl)
    y_hat = tft.predict(val_dl).cpu().numpy().flatten()
    y     = np.concatenate([y_[0].numpy() for _, y_ in val_dl])
    mae   = mean_absolute_error(y, y_hat)
    rmse  = math.sqrt(mean_squared_error(y, y_hat))
    vol   = np.mean(np.abs((y_hat - y) / y) >= THRESHOLD) * 100
    try:
        fig_att = tft.plot_interpretation(tft.interpret_output(val_dl, reduction='sum'))
        fig_att.figure.savefig(OUTPUT_DIR/f'attention_s{sid}.png', dpi=300)
        plt.close(fig_att.figure)
    except Exception as e:
        print('Attention plot failed:', e)
    return {'scenario':sid,'MAE':mae,'RMSE':rmse,'VolatilityRate':vol}

perf = pd.DataFrame([run_scenario(s, data) for s in (1,2,3,4)])
perf.to_csv(OUTPUT_DIR/'performance.csv', index=False)
perf


In [None]:

fig, ax = plt.subplots(1,3, figsize=(15,4))
ax[0].bar(perf.scenario, perf.MAE); ax[0].set_title('MAE')
ax[1].bar(perf.scenario, perf.RMSE); ax[1].set_title('RMSE')
ax[2].bar(perf.scenario, perf.VolatilityRate); ax[2].set_title('Volatility %')
for a in ax: a.set_xlabel('Scenario')
plt.tight_layout()
plt.savefig(OUTPUT_DIR/'metrics_bar.png', dpi=300)
plt.show()


In [13]:

def surge_drop(df, k_atr=2):
    d = df.copy()
    d['next'] = d['close'].shift(-1)
    d['chg'] = (d['next'] / d['close'] - 1) * 100
    atr = d['high'].sub(d['low']).rolling(14).mean()
    thr = k_atr * atr
    surge = d[d['chg'] >= thr]
    drop  = d[d['chg'] <= -thr]
    print('Surge', len(surge), 'Drop', len(drop))
    fig, ax = plt.subplots(1,2, figsize=(12,5))
    ax[0].boxplot([surge['soc_sent'], drop['soc_sent']], labels=['Surge','Drop'])
    ax[0].set_title('Social Sentiment')
    ax[1].boxplot([surge['news_sent'], drop['news_sent']], labels=['Surge','Drop'])
    ax[1].set_title('News Sentiment')
    plt.tight_layout()
    plt.savefig(OUTPUT_DIR/'boxplot_soc_news.png', dpi=300)
    plt.show()
    plt.figure(figsize=(8,4))
    plt.scatter(d['soc_sent'], d['chg'], alpha=.5, label='Social')
    plt.scatter(d['news_sent'], d['chg'], alpha=.5, label='News')
    plt.legend()
    plt.xlabel('Sentiment')
    plt.ylabel('Δ%')
    plt.title('Sentiment vs Price Move')
    plt.tight_layout()
    plt.savefig(OUTPUT_DIR/'scatter_soc_news.png', dpi=300)
    plt.show()

surge_drop(data)


NameError: name 'data' is not defined